diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..4161c2d5a4e54e731a356656bbff8864326c7fee --- /dev/null +++ b/.azuredevops/rocm-ci.yml @@ -0,0 +1,29 @@ +resources: + repositories: + - repository: pipelines_repo + type: github + endpoint: ROCm + name: ROCm/ROCm + +variables: +- group: common +- template: /.azuredevops/variables-global.yml@pipelines_repo + +trigger: + batch: true + branches: + include: + - develop + paths: + exclude: + - .github + - docs + - '.*.y*ml' + - '*.md' + - Jenkinsfile + - LICENSE + +pr: none + +jobs: + - template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo diff --git a/.clang-tidy b/.clang-tidy index 5c2b78168747787b506b0cd6e72af595eb270670..3815c654fe91969e18403b9e74c8e01971ae804b 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,3 +1,3 @@ CheckOptions: - key: bugprone-reserved-identifier.AllowedIdentifiers - value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__' + value: '__HIP_PLATFORM_HCC__;__HIP_PLATFORM_AMD__;__HIP_ROCclr__' diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..459315e58b766043355046569379ab96500a3449 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,8 @@ +* @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +# Documentation files +docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +# Header directory for Doxygen documentation +library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..0086358db1eb971c0cfa8739c27518bbc18a5ff4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: true diff --git a/.github/ISSUE_TEMPLATE/issue_report.yml b/.github/ISSUE_TEMPLATE/issue_report.yml new file mode 100644 index 0000000000000000000000000000000000000000..ef6e6faa1b71b07865f1e436c7a0354d2a0f8c18 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/issue_report.yml @@ -0,0 +1,221 @@ +name: Issue Report +description: File a report for ROCm related issues on Linux and Windows. For issues pertaining to documentation or non-bug related, please open a blank issue located below. +title: "[Issue]: " + +body: +- type: markdown + attributes: + value: | + Thank you for taking the time to fill out this report! + + You can acquire your OS, CPU, GPU (for filling out this report) with the following commands: + + Linux: + echo "OS:" && cat /etc/os-release | grep -E "^(NAME=|VERSION=)"; + echo "CPU: " && cat /proc/cpuinfo | grep "model name" | sort --unique; + echo "GPU:" && /opt/rocm/bin/rocminfo | grep -E "^\s*(Name|Marketing Name)"; + + Windows: + (Get-WmiObject Win32_OperatingSystem).Version + (Get-WmiObject win32_Processor).Name + (Get-WmiObject win32_VideoController).Name +- type: textarea + attributes: + label: Problem Description + description: Describe the issue you encountered. 
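      # The three Linux commands above can also be captured into a single file that is easy to
      # attach to the report; the grouping and the output file name below are only an example:
      #   { cat /etc/os-release | grep -E "^(NAME=|VERSION=)"; \
      #     cat /proc/cpuinfo | grep "model name" | sort --unique; \
      #     /opt/rocm/bin/rocminfo | grep -E "^\s*(Name|Marketing Name)"; } > system_info.txt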
+ validations: + required: true +- type: input + attributes: + label: Operating System + description: What is the name and version number of the OS? + placeholder: "e.g. Ubuntu 22.04.3 LTS (Jammy Jellyfish)" + validations: + required: true +- type: input + attributes: + label: CPU + description: What CPU did you encounter the issue on? + placeholder: "e.g. AMD Ryzen 9 5900HX with Radeon Graphics" + validations: + required: true +- type: dropdown + attributes: + label: GPU + description: What GPU(s) did you encounter the issue on (you can select multiple GPUs from the list) + multiple: true + options: + - AMD Instinct MI300X + - AMD Instinct MI300A + - AMD Instinct MI300 + - AMD Instinct MI250X + - AMD Instinct MI250 + - AMD Instinct MI210 + - AMD Instinct MI100 + - AMD Instinct MI50 + - AMD Instinct MI25 + - AMD Radeon Pro V620 + - AMD Radeon Pro VII + - AMD Radeon RX 7900 XTX + - AMD Radeon VII + - AMD Radeon Pro W7900 + - AMD Radeon Pro W7800 + - AMD Radeon Pro W6800 + - AMD Radeon Pro W6600 + - AMD Radeon Pro W5500 + - AMD Radeon RX 7900 XT + - AMD Radeon RX 7600 + - AMD Radeon RX 6950 XT + - AMD Radeon RX 6900 XT + - AMD Radeon RX 6800 XT + - AMD Radeon RX 6800 + - AMD Radeon RX 6750 + - AMD Radeon RX 6700 XT + - AMD Radeon RX 6700 + - AMD Radeon RX 6650 XT + - AMD Radeon RX 6600 XT + - AMD Radeon RX 6600 + - Other + validations: + required: true +- type: input + attributes: + label: Other + description: If you selected Other, please specify +- type: dropdown + attributes: + label: ROCm Version + description: What version(s) of ROCm did you encounter the issue on? + multiple: true + options: + - ROCm 6.0.0 + - ROCm 5.7.1 + - ROCm 5.7.0 + - ROCm 5.6.1 + - ROCm 5.6.0 + - ROCm 5.5.1 + - ROCm 5.5.0 + validations: + required: true +- type: dropdown + attributes: + label: ROCm Component + description: (Optional) If this issue relates to a specific ROCm component, it can be mentioned here. 
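  # For the ROCm Version field above, the installed release can usually be read directly from the
  # install tree on Linux; the paths below assume a default /opt/rocm installation:
  #   cat /opt/rocm/.info/version
  #   dpkg -l | grep rocm-core    # Debian/Ubuntu packaging, if applicable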
+ multiple: true + options: + - Other + - AMD Common Language Runtime + - AMD MIGraphX + - AMD System Management Interface + - amdgpu KCL/autoconf + - amdgpu Kernel-mode GPU Driver + - amdgpu-install + - AOMP + - AOMP Extras + - AqlProfile + - build-infra + - chelsio + - clang-ocl + - Composable Kernel + - dkms + - docker / ROCm-docker + - flang + - gpuburn + - half + - HIP + - HIP Examples + - hipBLAS + - hipBLASLt + - HIPCC + - hipCUB + - hip-examples-private + - hipFFT + - hipfort + - HIPIFY + - hipRAND + - hipSOLVER + - hipSPARSE + - hipSPARSELt + - hipTensor + - hip-tests + - HSA Runtime + - infrastructure + - jenkins-utils + - libdrm + - Linux BPI packaging framework + - llvm-project + - Mesa + - meta + - MIOpen + - MIVisionX + - ml-framework-ci + - MLSEQA_TestRepo + - OpenCL API C++ Bindings + - OpenCL API Headers + - OpenCL Conformance Test Suite + - OpenCL ICD Loader + - perftest-p2p + - prototype + - RCCL + - rccl-rdma-sharp-plugins + - rocALUTION + - rocBLAS + - ROCdbgapi + - ROCdebug-agent + - rocFFT + - ROCgdb + - ROCK + - ROCm Documentation/Website + - ROCm Data Center Tool + - ROCm Examples + - ROCm for Windows + - ROCm Performance Primitives + - ROCm System Management Interface Library + - ROCm Thrust + - ROCm Validation Suite + - rocm_bandwidth_test + - rocm-cmake + - rocm-core + - rocm-docs-core + - rocminfo + - rocMLIR + - rocmtools + - rocPRIM + - rocprofiler + - rocRAND + - ROCR-Runtime + - rocSOLVER + - rocSPARSE + - roctracer + - ROCT-Thunk-Interface + - rocWMMA + - Tensile + - umr + - ibv_rc_pingpong-amd + - mellanox + - mpitest + - Pytorch + - Tensorflow + - APEX + - torchvision + - Magma +- type: textarea + attributes: + label: Steps to Reproduce + description: (Optional) Detailed steps to reproduce the issue. + validations: + required: false + +- type: textarea + attributes: + label: (Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support + description: The output of rocminfo --support could help to better address the problem. + validations: + required: false + +- type: textarea + attributes: + label: Additional Information + description: (Optional) Any additional information that is relevant, e.g. relevant environment variables, dockerfiles, log files, dmesg output (on Linux), etc. 
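  # As one example of the log files mentioned above, amdgpu-related kernel messages on Linux can
  # be collected with (the output file name is just a suggestion):
  #   sudo dmesg | grep -i amdgpu > dmesg_amdgpu.log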
+ validations: + required: false diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 276690bd4f96f50eb6a5121d237c8c9bcb80224d..0e0a252eb6ee9d8508d2b21f42df3c77a1739832 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -10,3 +10,9 @@ updates: open-pull-requests-limit: 10 schedule: interval: "daily" + labels: + - "documentation" + - "dependencies" + - "ci:docs-only" + reviewers: + - "samjwu" diff --git a/.gitignore b/.gitignore index 4a668630bcf6daaf79ab2df6f8ebe98860d47c8f..e18b2212c45696fbd20332062ed7450b40a54993 100644 --- a/.gitignore +++ b/.gitignore @@ -61,8 +61,15 @@ _images/ _static/ _templates/ _toc.yml -docBin/ _doxygen/ -# pycache +# JetBrains IDE +.idea/ +cmake-build*/ +build*/ + +# Python virtualenv +.venv/ + +# Python cache __pycache__/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 new mode 100755 diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 5f50df2525d2948bbabe36184f23310d47339d8e..b3299fa4e8830d534931a29faf22d3a9020de951 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,11 +3,6 @@ version: 2 -build: - os: ubuntu-22.04 - tools: - python: "3.8" - sphinx: configuration: docs/conf.py @@ -16,3 +11,8 @@ formats: [htmlzip, pdf, epub] python: install: - requirements: docs/sphinx/requirements.txt + +build: + os: ubuntu-22.04 + tools: + python: "3.10" diff --git a/CHANGELOG.md b/CHANGELOG.md index 31f129b58159076c3d3027e3c36ccc885e4b1baf..dec6334cf5b596aacde281000685d7a081b986a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,32 +1,87 @@ -# Change Log for Composable Kernel +# Changelog for Composable Kernel -Full documentation for Composable Kernel is not yet available. +Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). -## CK 0.2.0 for ROCm 5.5.0 +## Composable Kernel 1.1.0 for ROCm 6.1.0 -### Fixed -- Fixed a bug in 6-dimensional kernels (#555). -- Fixed grouped ConvBwdWeight test case failure (#524). +### Additions + +* Added generic instances for GEMM XDL operations (#1161) +* Added gamma and beta parameters for the layernorm and groupnorm bwd operations (#1133) +* Introduced wrapper sublibrary (limited functionality). 
(#1071, #1098, #1108, #1126) +* Added an option to vary the number of warm-up cycles and iterations for ckProfiler (#1124) + +### Optimizations + +* New performance optimizations for GEMM operations on MI200 and MI300 architectures (#1135) + +### Fixes + +* Reduced the build time for most GPU architectures (#1084) +* Fixed some conversion issues for fp8 data type (#1099) + +### Changes + +None + +### Known issues + +None + +## Composable Kernel 1.1.0 for ROCm 6.0.0 + +### Fixes + +* Fixed a hazard associated with inline v_dot (#808) +* Fixed two bugs in grouped convolution backward data without K padding (#848 #876) + +### Optimizations + +None + +### Additions + +* Added an image to a column kernel (#867) +* Added a column to an image kernel (#930) +* Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) +* Grouped convolution support for small K and C (#822 #879 #897) +* Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) +* Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) +* Support for Batched GEMM DL (#732) + +### Changes + +* Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) + +## Composable Kernel 0.2.0 for ROCm 5.7.0 + +### Fixes + +* Fixed a bug in 6-dimensional kernels (#555) +* Fixed a test case failure with grouped convolution backward weight (#524) ### Optimizations -- Improve proformance of normalization kernel - -### Added -- Added new cmake flag "DL_KERNELS" must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances. -- Added new cmake flag "DTYPES" which could be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instance of select data types. -- Added new cmake flag "INSTANCES_ONLY" which will only build CK library and instances without the tests, examples, or profiler. -- Added new feature: if GPU_TARGETS is not set on cmake command line, CK will be built for all targets supported by compiler. -- Added support on MI300A/MI300X. -- Added support on NAVI3x. -- Added user tutorial (#563). -- Added more instances for irregular GEMM sizes (#560). -- Added inter-wave consumer-producer programming model for GEMM kernels (#310). -- Added multi-D GEMM client APIs (#534). -- Added multi-embeddings support (#542). -- Added Navi3x blockwise GEMM and real GEMM support (#541). -- Added Navi grouped ConvBwdWeight support (#505). -- Added MaxPool, AvgPool forward (#815). -- Added MaxPool backward (#750). - -### Changed -- Changed ... 
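For illustration only (not part of the release notes): the CMake flags described in this entry are combined on the configure command line, for example with placeholder target and data-type choices:

    cmake -D GPU_TARGETS="gfx90a" -D DTYPES="fp16;fp32" -D DL_KERNELS=ON -D INSTANCES_ONLY=ON ..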
+ +* Improved the performance of the normalization kernel + +### Additions + +* New CMake flags: + * "DL_KERNELS"-* Must be set to "ON" in order to build the GEMM DL and batched_gemm_multi_d_dl instances + * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types + * "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler +* New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler +* Support for MI300A/MI300X +* Support for AMD RDNA 3 +* New user tutorial (#563) +* Additional instances for irregular GEMM sizes (#560) +* New inter-wave consumer-producer programming model for GEMM kernels (#310) +* GEMM with support multiple elementwise fusions (multi-D) (#534) +* Multi-embeddings support (#542) +* AMD RDNA 3 blockwise GEMM and real GEMM support (#541) +* AMD RDNA grouped convolution backward weight support (#505) +* MaxPool and AvgPool forward (#815); MaxPool backward (#750) + +### Changes + +None diff --git a/CITATION.cff b/CITATION.cff index d35fe9e5870f0a538c1775efd239dc729e11ab88..3813d63812566760017b6f2d3bfeec0ecd11f884 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -59,9 +59,9 @@ authors: family-names: Zhou - given-names: Jianfeng family-names: Yan -repository-code: 'https://github.com/ROCmSoftwarePlatform/composable_kernel' +repository-code: 'https://github.com/ROCm/composable_kernel' abstract: Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for Machine Learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel progarmming languages, like HIP C++. keywords: - 'CK, Composable Kernel, Tensor Coordinate Transformation' license: MIT -license-url: https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE +license-url: https://github.com/ROCm/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE diff --git a/CMakeLists.txt b/CMakeLists.txt index 7ef5e1ca073003a8a130743541f8fa63b8505cc4..f44c0cbde6d76307c7d53c8cc4496a507c824adc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,48 @@ cmake_minimum_required(VERSION 3.14) +if(POLICY CMP0140) + # policies CMP0140 not known to CMake until 3.25 + cmake_policy(SET CMP0140 NEW) +endif() + +get_property(_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) + +# This has to be initialized before the project() command appears +# Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. 
MSVC_IDE does not use CMAKE_BUILD_TYPE +if(_GENERATOR_IS_MULTI_CONFIG) + set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo;MinSizeRel" CACHE STRING + "Available build types (configurations) on multi-config generators") +else() + set(CMAKE_BUILD_TYPE Release CACHE STRING + "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.") +endif() + +# Default installation path +if(NOT WIN32) + set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") +endif() set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel VERSION ${version}) +project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) +include(CTest) + +# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8" +# CK Codegen requires dataclass which is added in Python 3.7 +# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04 +if(NOT CK_USE_ALTERNATIVE_PYTHON) + find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED) +else() + message("Using alternative python version") + set(EXTRA_PYTHON_PATH) + # this is overly restrictive, we may need to be more flexible on the following + string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}") + message("alternative python path is: ${EXTRA_PYTHON_PATH}") + find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) + add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}") + set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}") + set(PYTHON_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}") + set(ENV{LD_LIBRARY_PATH} "${EXTRA_PYTHON_PATH}/lib:$ENV{LD_LIBRARY_PATH}") +endif() list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") @@ -16,6 +56,10 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP8) set(CK_ENABLE_FP8 "ON") endif() + if (DTYPES MATCHES "bf8") + add_definitions(-DCK_ENABLE_BF8) + set(CK_ENABLE_BF8 "ON") + endif() if (DTYPES MATCHES "fp16") add_definitions(-DCK_ENABLE_FP16) set(CK_ENABLE_FP16 "ON") @@ -34,29 +78,36 @@ if (DTYPES) endif() message("DTYPES macro set to ${DTYPES}") else() - add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) - set(CK_ENABLE_ALL_DTYPES "ON") + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8) + set(CK_ENABLE_INT8 "ON") + set(CK_ENABLE_FP16 "ON") + set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_FP64 "ON") + set(CK_ENABLE_BF16 "ON") + set(CK_ENABLE_FP8 "ON") + set(CK_ENABLE_BF8 "ON") endif() +#for f8/bf8_t type +add_compile_options(-Wno-bit-int-extension) +add_compile_options(-Wno-pass-failed) +add_compile_options(-Wno-switch-default) + if(DL_KERNELS) add_definitions(-DDL_KERNELS) set(CK_ENABLE_DL_KERNELS "ON") endif() - -if(INSTANCES_ONLY) - add_definitions(-DINSTANCES_ONLY) - set(CK_ENABLE_INSTANCES_ONLY "ON") +option(CK_USE_CODEGEN "Enable codegen library" OFF) +if(CK_USE_CODEGEN) + add_definitions(-DCK_USE_CODEGEN) endif() -# CK config file to record supported datatypes, etc. 
-configure_file("${PROJECT_SOURCE_DIR}/include/ck/config.h.in" "${PROJECT_BINARY_DIR}/include/ck/config.h") +include(getopt) # CK version file to record release version as well as git commit hash find_package(Git REQUIRED) execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE) -configure_file("${PROJECT_SOURCE_DIR}/include/ck/version.h.in" "${PROJECT_BINARY_DIR}/include/ck/version.h") - -enable_testing() +configure_file(include/ck/version.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/version.h) set(ROCM_SYMLINK_LIBS OFF) find_package(ROCM REQUIRED PATHS /opt/rocm) @@ -72,45 +123,144 @@ include(TargetFlags) rocm_setup_version(VERSION ${version}) -list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) +list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}") message("GPU_TARGETS= ${GPU_TARGETS}") +message("GPU_ARCHS= ${GPU_ARCHS}") +if(GPU_ARCHS) + #disable GPU_TARGETS to avoid conflicts, this needs to happen before we call hip package + unset(GPU_TARGETS CACHE) + unset(AMDGPU_TARGETS CACHE) +endif() +if(GPU_TARGETS) + set(USER_GPU_TARGETS 1) +else() + set(USER_GPU_TARGETS 0) +endif() +find_package(hip REQUIRED) +# No assumption that HIP kernels are launched with uniform block size for backward compatibility +# SWDEV-413293 and https://reviews.llvm.org/D155213 +math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}") +message("hip_version_flat=${hip_VERSION_FLAT}") message("checking which targets are supported") -#This is the list of targets to be used in case GPU_TARGETS is not set on command line -#These targets will be filtered and only supported ones will be used -#Setting GPU_TARGETS on command line will override this list -if(NOT PROFILER_ONLY) - rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") -else() - add_definitions(-DPROFILER_ONLY) - if(GPU_TARGETS) - message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx9, gfx10, or gfx11") - endif() - if(GPU_ARCH MATCHES "gfx9") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942") - elseif(GPU_ARCH MATCHES "gfx10") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") - elseif(GPU_ARCH MATCHES "gfx11") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") +#In order to build just the CK library (without tests and examples) for all supported GPU targets +#use -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" +#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts. +# +#In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures. 
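#As a sketch only (the compiler path and target choices below are examples, not requirements):
#  library-only build for several targets:
#    cmake -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ -D GPU_ARCHS="gfx908;gfx90a;gfx942" ..
#  full build (library, tests, examples, profiler) for one or two targets:
#    cmake -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ -D GPU_TARGETS="gfx90a;gfx942" -D BUILD_DEV=ON ..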
+if(NOT ENABLE_ASAN_PACKAGING) + if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000) + # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") else() - message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx9, gfx10, or gfx11") + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") endif() +else() + #build CK only for xnack-supported targets when using ASAN + set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+") endif() -message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") +#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list +#otherwise, if user set GPU_TARGETS, use that set of targets +if(GPU_ARCHS) + set(CK_GPU_TARGETS ${GPU_ARCHS}) +else() + if(USER_GPU_TARGETS) + set(CK_GPU_TARGETS ${GPU_TARGETS}) + endif() +endif() +#if the user did not set GPU_TARGETS, delete whatever was set by HIP package +if(NOT USER_GPU_TARGETS) + set(GPU_TARGETS "") +endif() +#make sure all the targets on the list are actually supported by the current compiler +rocm_check_target_ids(SUPPORTED_GPU_TARGETS + TARGETS ${CK_GPU_TARGETS}) -set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " ") +message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") -if(GPU_TARGETS) - message("Building CK for the following targets: ${GPU_TARGETS}") -else() - message("Building CK for the following targets: ${AMDGPU_TARGETS}") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") + message("Enabling XDL instances") + add_definitions(-DCK_USE_XDL) +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") + message("Enabling FP8 gemms in ckProfiler") + add_definitions(-DCK_USE_GFX94) +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") + message("Enabling WMMA instances") + add_definitions(-DCK_USE_WMMA) +endif() +option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) +if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) + add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) +endif() + +# CK config file to record supported datatypes, etc. 
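# For illustration only: opting in to the FP8 instances on the older gfx908/gfx90a targets
# described above might look like the following (the target choice is an example):
#   cmake -D GPU_TARGETS="gfx90a" -D CK_USE_FP8_ON_UNSUPPORTED_ARCH=ON ..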
+configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/config.h) + +if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302) + check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK) + if(HAS_NO_OFFLOAD_UNIFORM_BLOCK) + message("Adding the fno-offload-uniform-block compiler flag") + add_compile_options(-fno-offload-uniform-block) + endif() +endif() +if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000) + check_cxx_compiler_flag("-mllvm --lsr-drop-solution=1" HAS_LSR_DROP_SOLUTION) + if(HAS_LSR_DROP_SOLUTION) + message("Adding the lsr-drop-solution=1 compiler flag") + add_compile_options("SHELL: -mllvm --lsr-drop-solution=1") + endif() +endif() +if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090) + check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED) + if(HAS_ENABLE_POST_MISCHED) + message("Adding the enable-post-misched=0 compiler flag") + add_compile_options("SHELL: -mllvm -enable-post-misched=0") + endif() +endif() +set(check-coerce) +check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) +if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132) + message("Adding the amdgpu-coerce-illegal-types=1") + add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1") +endif() +if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) + message("Adding -amdgpu-early-inline-all=true and -amdgpu-function-calls=false") + add_compile_options("SHELL: -mllvm -amdgpu-early-inline-all=true") + add_compile_options("SHELL: -mllvm -amdgpu-function-calls=false") +endif() +# +# Seperate linking jobs from compiling +# Too many concurrent linking jobs can break the build +# Copied from LLVM +set(CK_PARALLEL_LINK_JOBS "" CACHE STRING + "Define the maximum number of concurrent link jobs (Ninja only).") +if(CMAKE_GENERATOR MATCHES "Ninja") + if(CK_PARALLEL_LINK_JOBS) + set_property(GLOBAL APPEND PROPERTY JOB_POOLS link_job_pool=${CK_PARALLEL_LINK_JOBS}) + set(CMAKE_JOB_POOL_LINK link_job_pool) + endif() +elseif(CK_PARALLEL_LINK_JOBS) + message(WARNING "Job pooling is only available with Ninja generators.") +endif() +# Similar for compiling +set(CK_PARALLEL_COMPILE_JOBS "" CACHE STRING + "Define the maximum number of concurrent compile jobs (Ninja only).") +if(CMAKE_GENERATOR MATCHES "Ninja") + if(CK_PARALLEL_COMPILE_JOBS) + set_property(GLOBAL APPEND PROPERTY JOB_POOLS compile_job_pool=${CK_PARALLEL_COMPILE_JOBS}) + set(CMAKE_JOB_POOL_COMPILE compile_job_pool) + endif() +elseif(CK_PARALLEL_COMPILE_JOBS) + message(WARNING "Job pooling is only available with Ninja generators.") endif() -option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) -option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) + +option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) +option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." 
OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -118,10 +268,10 @@ if(USE_BITINT_EXTENSION_INT4) message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() -if(USE_OPT_NAVI3X) +if(USE_OPT_GFX11) add_compile_options(-mcumode) add_compile_options(-mno-wavefrontsize64) - message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}") + message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") endif() ## Threads @@ -130,11 +280,26 @@ find_package(Threads REQUIRED) link_libraries(Threads::Threads) ## C++ -enable_language(CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") + +# https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_macros.html +# _GLIBCXX_ASSERTIONS +# Undefined by default. When defined, enables extra error checking in the form of +# precondition assertions, such as bounds checking in strings and null pointer +# checks when dereferencing smart pointers +option(USE_GLIBCXX_ASSERTIONS "Turn on additional c++ library checks." OFF) +if(USE_GLIBCXX_ASSERTIONS) + add_compile_options(-Wp,-D_GLIBCXX_ASSERTIONS) +endif() + +## HIP +set(CMAKE_HIP_PLATFORM amd) +set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER}) +set(CMAKE_HIP_EXTENSIONS ON) +message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") option(CK_BUILD_JIT_LIB, "Only build the CK JIT Helper Library" OFF) if (NOT CK_BUILD_JIT_LIB) @@ -156,61 +321,32 @@ if (NOT CK_BUILD_JIT_LIB) add_compile_options(-Wno-bit-int-extension) message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() +endif() - if(USE_OPT_NAVI3X) - add_compile_options(-mcumode) - add_compile_options(-mno-wavefrontsize64) - message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}") - endif() - - ## Threads - set(THREADS_PREFER_PTHREAD_FLAG ON) - find_package(Threads REQUIRED) - link_libraries(Threads::Threads) - - ## OpenMP - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - # workaround issue hipcc in rocm3.5 cannot find openmp - set(OpenMP_CXX "${CMAKE_CXX_COMPILER}") - set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument") - set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5") - set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) - set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) - set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES}) - else() - find_package(OpenMP REQUIRED) - endif() - - message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}") - message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") - message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") - message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") - - link_libraries(${OpenMP_gomp_LIBRARY}) - link_libraries(${OpenMP_pthread_LIBRARY}) - - ## HIP - find_package(HIP REQUIRED) - # Override HIP version in config.h, if necessary. - # The variables set by find_package() can't be overwritten, - # therefore let's use intermediate variables. 
- set(CK_HIP_VERSION_MAJOR "${HIP_VERSION_MAJOR}") - set(CK_HIP_VERSION_MINOR "${HIP_VERSION_MINOR}") - set(CK_HIP_VERSION_PATCH "${HIP_VERSION_PATCH}") - if( DEFINED CK_OVERRIDE_HIP_VERSION_MAJOR ) - set(CK_HIP_VERSION_MAJOR "${CK_OVERRIDE_HIP_VERSION_MAJOR}") - message(STATUS "CK_HIP_VERSION_MAJOR overriden with ${CK_OVERRIDE_HIP_VERSION_MAJOR}") - endif() - if( DEFINED CK_OVERRIDE_HIP_VERSION_MINOR ) - set(CK_HIP_VERSION_MINOR "${CK_OVERRIDE_HIP_VERSION_MINOR}") - message(STATUS "CK_HIP_VERSION_MINOR overriden with ${CK_OVERRIDE_HIP_VERSION_MINOR}") - endif() - if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) - set(CK_HIP_VERSION_PATCH "${CK_OVERRIDE_HIP_VERSION_PATCH}") - message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}") - endif() - message(STATUS "Build with HIP ${HIP_VERSION}") - link_libraries(hip::device) +## HIP +# Override HIP version in config.h, if necessary. +# The variables set by find_package() can't be overwritten, +# therefore let's use intermediate variables. +set(CK_HIP_VERSION_MAJOR "${HIP_VERSION_MAJOR}") +set(CK_HIP_VERSION_MINOR "${HIP_VERSION_MINOR}") +set(CK_HIP_VERSION_PATCH "${HIP_VERSION_PATCH}") +if( DEFINED CK_OVERRIDE_HIP_VERSION_MAJOR ) + set(CK_HIP_VERSION_MAJOR "${CK_OVERRIDE_HIP_VERSION_MAJOR}") + message(STATUS "CK_HIP_VERSION_MAJOR overriden with ${CK_OVERRIDE_HIP_VERSION_MAJOR}") +endif() +if( DEFINED CK_OVERRIDE_HIP_VERSION_MINOR ) + set(CK_HIP_VERSION_MINOR "${CK_OVERRIDE_HIP_VERSION_MINOR}") + message(STATUS "CK_HIP_VERSION_MINOR overriden with ${CK_OVERRIDE_HIP_VERSION_MINOR}") +endif() +if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) + set(CK_HIP_VERSION_PATCH "${CK_OVERRIDE_HIP_VERSION_PATCH}") + message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}") +endif() +message(STATUS "Build with HIP ${HIP_VERSION}") +link_libraries(hip::device) +if(CK_hip_VERSION VERSION_GREATER_EQUAL 6.0.23494) + add_compile_definitions(__HIP_PLATFORM_AMD__=1) +else() add_compile_definitions(__HIP_PLATFORM_HCC__=1) endif() @@ -222,7 +358,6 @@ if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\ # Enable tidy on hip elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU") set(CK_TIDY_ERRORS ALL) -endif() include(ClangTidy) @@ -378,6 +513,15 @@ if (NOT CK_BUILD_JIT_LIB) endif() message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + add_compile_options(-fcolor-diagnostics) + endif() + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9) + add_compile_options(-fdiagnostics-color=always) + endif() + + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) + file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) set(CK_DEVICE_INSTANCES) @@ -387,32 +531,28 @@ if (NOT CK_BUILD_JIT_LIB) set(cmake_instance) file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance) set(add_inst 0) - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8") - #message("fp8 instance found!") + if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8") + set(add_inst 1) + endif() + if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND 
DTYPES MATCHES "bf8") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") - #message("fp16 instance found!") + if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32") - #message("fp32 instance found!") + if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64") - #message("fp64 instance found!") + if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16") - #message("bf16 instance found!") + if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8") - #message("int8 instance found!") + if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8") set(add_inst 1) endif() if(NOT "${cmake_instance}" MATCHES "DTYPES") - #message("instance should be built for all types!") set(add_inst 1) endif() if(add_inst EQUAL 1 OR NOT DEFINED DTYPES) @@ -421,6 +561,49 @@ if (NOT CK_BUILD_JIT_LIB) ENDIF() ENDFOREACH() + file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") + file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) + set(CK_DEVICE_INSTANCES) + FOREACH(subdir_path ${dir_list}) + set(target_dir) + IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}") + set(cmake_instance) + file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance) + set(add_inst 0) + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8") + #message("fp8 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") + #message("fp16 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32") + #message("fp32 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64") + #message("fp64 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16") + #message("bf16 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8") + #message("int8 instance found!") + set(add_inst 1) + endif() + if(NOT "${cmake_instance}" MATCHES "DTYPES") + #message("instance should be built for all types!") + set(add_inst 1) + endif() + if(add_inst EQUAL 1 OR NOT DEFINED DTYPES) + list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance) + endif() + ENDIF() + ENDFOREACH() + add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) add_subdirectory(library) @@ -452,6 +635,31 @@ if (NOT CK_BUILD_JIT_LIB) add_subdirectory(profiler) endif() endif() + if(NOT GPU_ARCHS AND USER_GPU_TARGETS) + 
rocm_package_setup_component(tests + LIBRARY_NAME composablekernel + PACKAGE_NAME tests # Prevent -static suffix on package name + ) + + rocm_package_setup_component(examples + LIBRARY_NAME composablekernel + PACKAGE_NAME examples + ) + add_subdirectory(example) + if(BUILD_TESTING) + add_subdirectory(test) + endif() + endif() + + rocm_package_setup_component(profiler + LIBRARY_NAME composablekernel + PACKAGE_NAME ckprofiler + ) + add_subdirectory(profiler) + + if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS)) + add_subdirectory(codegen) + endif() else() rocm_package_setup_component(jit_library LIBRARY_NAME composablekernel @@ -461,6 +669,8 @@ else() add_subdirectory(test) endif() + + #Create an interface target for the include only files and call it "composablekernels" include(CMakePackageConfigHelpers) @@ -483,7 +693,7 @@ rocm_install(FILES ) # Install CK version and configuration files -install(FILES +rocm_install(FILES ${PROJECT_BINARY_DIR}/include/ck/version.h ${PROJECT_BINARY_DIR}/include/ck/config.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/ diff --git a/Config.cmake.in b/Config.cmake.in index 0a19915565f1b2d719f8c0687689daf31f5fc8c4..b5fe3b847315d8882d71e541156ab40b7cb1e26a 100644 --- a/Config.cmake.in +++ b/Config.cmake.in @@ -1,6 +1,6 @@ @PACKAGE_INIT@ -set(_composable_kernel_supported_components device_operations utility jit_library) +set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility jit_library) foreach(_comp ${composable_kernel_FIND_COMPONENTS}) if(NOT _comp IN_LIST _composable_kernel_supported_components) diff --git a/Dockerfile b/Dockerfile index e479268f48050957fabdaca7cfc5d162610f4fc1..791d1d9f3ab8143e74ead44187e7a8bb9b764f8f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,9 @@ FROM ubuntu:20.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=5.6 +ARG ROCMVERSION=6.2 ARG compiler_version="" ARG compiler_commit="" +ARG CK_SCCACHE="" RUN set -xe @@ -16,72 +17,88 @@ RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN wget https://repo.radeon.com/amdgpu-install/5.6/ubuntu/focal/amdgpu-install_5.6.50600-1_all.deb --no-check-certificate -RUN apt-get update && \ -DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ - ./amdgpu-install_5.6.50600-1_all.deb - -RUN if [ "$ROCMVERSION" != "5.7" ]; then \ +RUN if [ "$ROCMVERSION" != "6.3" ]; then \ + sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/focal/amdgpu-install_6.2.60200-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.2.60200-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "" ] || [ "$compiler_version" = "amd-stg-open" ]; then \ - sh -c "wget 
http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \ - apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \ - amdgpu-repo --amdgpu-build=1609671 --rocm-build=compute-rocm-npi-mi300/1354; \ - elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "rc1" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \ - apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 5.7 rel-19 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=1637781; \ + elif [ "$ROCMVERSION" = "6.3" ] && [ "$compiler_version" = "rc1" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3-20.04-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3 rel-20 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=2074281; \ fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" RUN amdgpu-install -y --usecase=rocm --no-dkms +## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined +ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache +ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin +ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} +ENV CK_SCCACHE=$CK_SCCACHE +RUN if [ "$CK_SCCACHE" != "" ]; then \ + mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ + curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ + chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \ + fi + # Install dependencies RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ build-essential \ - ccache \ cmake \ git \ hip-rocclr \ + iputils-ping \ jq \ libelf-dev \ libncurses5-dev \ libnuma-dev \ libpthread-stubs0-dev \ llvm-amdgpu \ + net-tools \ pkg-config \ python \ python3 \ python3-dev \ python3-pip \ + redis \ + rocm-llvm-dev \ sshpass \ + stunnel \ software-properties-common \ vim \ nano \ zlib1g-dev \ + zip \ openssh-server \ clang-format-12 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -#Install latest version of cmake +# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1 +RUN if [ "$ROCMVERSION" = "6.1" ]; then \ + sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev"; \ + fi +# Update the cmake to version 3.27.5 +RUN pip install --upgrade cmake==3.27.5 + +#Install latest ccache +RUN git clone https://github.com/ccache/ccache.git && \ + cd ccache && mkdir build && cd build && cmake .. 
&& make install + +#Install ninja build tracing tools RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip RUN gunzip /usr/local/bin/ninja.gz RUN chmod a+x /usr/local/bin/ninja RUN git clone https://github.com/nico/ninjatracing.git -RUN apt purge --auto-remove -y cmake -RUN apt update -RUN apt install -y software-properties-common lsb-release -RUN apt clean all -RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null -RUN apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" -RUN apt install -y kitware-archive-keyring -RUN rm /etc/apt/trusted.gpg.d/kitware.gpg -RUN apt install -y cmake + +#Install latest cppcheck +RUN git clone https://github.com/danmar/cppcheck.git && \ + cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . +WORKDIR / # Setup ubsan environment to printstacktrace RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer @@ -96,9 +113,9 @@ ARG PREFIX=/opt/rocm RUN pip3 install --upgrade pip RUN pip3 install sqlalchemy==1.4.46 RUN pip3 install pymysql -RUN pip3 install pandas +RUN pip3 install pandas==2.0.3 RUN pip3 install setuptools-rust -RUN pip3 install sshtunnel +RUN pip3 install sshtunnel==0.4.0 # Setup ubsan environment to printstacktrace ENV UBSAN_OPTIONS=print_stacktrace=1 @@ -107,7 +124,7 @@ ENV LANG=C.UTF-8 RUN groupadd -f render # Install the new rocm-cmake version -RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git && \ +RUN git clone -b master https://github.com/ROCm/rocm-cmake.git && \ cd rocm-cmake && mkdir build && cd build && \ cmake .. && cmake --build . && cmake --build . --target install @@ -118,22 +135,26 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ - git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ +ARG DISABLE_CACHE=0 + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi -RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ - git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 
-DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi +#clean-up the deb package +RUN sh -c "rm -rf amdgpu-install*" #ENV HIP_CLANG_PATH='/llvm-project/build/bin' #RUN sh -c "echo HIP_CLANG_PATH = '$HIP_CLANG_PATH'" diff --git a/Jenkinsfile b/Jenkinsfile index 668d9a613bf3c0120076875b3882941530b66ae2..b79b2045b0750b2244d643df72bbeaaebd613f43 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,5 +1,5 @@ def rocmnode(name) { - return 'rocmtest && miopen && ' + name + return '(rocmtest || miopen) && (' + name + ')' } def show_node_info() { @@ -7,6 +7,7 @@ def show_node_info() { echo "NODE_NAME = \$NODE_NAME" lsb_release -sd uname -r + cat /sys/module/amdgpu/version ls /opt/ -la """ } @@ -33,7 +34,11 @@ def runShell(String command){ def getDockerImageName(){ def img - if (params.ROCMVERSION != "5.7"){ + if (params.USE_CUSTOM_DOCKER != ""){ + img = "${params.USE_CUSTOM_DOCKER}" + } + else{ + if (params.ROCMVERSION != "6.3"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" } @@ -61,14 +66,15 @@ def getDockerImageName(){ } } } + } return img } def check_host() { - if ("${env.CK_CCACHE}" != "null"){ - def CCACHE_SERVER="${env.CK_CCACHE.split(':')[0]}" - echo "ccache server: ${CCACHE_SERVER}" - sh '''ping -c 1 -p 6379 "${CCACHE_SERVER}" | echo $? > tmp.txt''' + if ("${env.CK_SCCACHE}" != "null"){ + def SCCACHE_SERVER="${env.CK_SCCACHE.split(':')[0]}" + echo "sccache server: ${SCCACHE_SERVER}" + sh '''ping -c 1 -p 6379 "${SCCACHE_SERVER}" | echo $? 
> tmp.txt''' def output = readFile(file: "tmp.txt") echo "tmp.txt contents: \$output" return (output != "0") @@ -80,53 +86,38 @@ def check_host() { def build_compiler(){ def compiler - if (params.BUILD_COMPILER == "hipcc"){ - compiler = '/opt/rocm/bin/hipcc' - } - else{ - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ - compiler = "/llvm-project/build/bin/clang++" - } - else{ - compiler = "/opt/rocm/llvm/bin/clang++" - } - } + compiler = "${params.BUILD_COMPILER}" return compiler } def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 - def prefixpath = conf.get("prefixpath", "/opt/rocm") // prefix:/opt/rocm + def prefixpath = conf.get("prefixpath", "/opt/rocm") def no_cache = conf.get("no_cache", false) - def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - echo "ccache server: ${env.CK_CCACHE}" - if(env.CK_CCACHE) - { - if(check_host()) - { - echo "FOUND CCACHE SERVER: ${env.CK_CCACHE}" - } - else - { - echo "CCACHE SERVER: ${env.CK_CCACHE} NOT FOUND, got ${check_host} response" - } - dockerArgs = dockerArgs + " --build-arg CCACHE_SECONDARY_STORAGE='redis://${env.CK_CCACHE}' --build-arg COMPILER_LAUNCHER='ccache' " - env.CCACHE_DIR = """/tmp/ccache_store""" - env.CCACHE_SECONDARY_STORAGE="""redis://${env.CK_CCACHE}""" - } + def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' " if(no_cache) { dockerArgs = dockerArgs + " --no-cache " } echo "Docker Args: ${dockerArgs}" - def image = getDockerImageName() + def image + if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){ + image = conf.get("docker_name", "") + echo "Using legacy docker: ${image}" + } + else{ + image = getDockerImageName() + echo "Using default docker: ${image}" + } //Check if image exists def retimage try { echo "Pulling down image: ${image}" retimage = docker.image("${image}") - retimage.pull() + withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage.pull() + } } catch(Exception ex) { @@ -141,41 +132,33 @@ def buildDocker(install_prefix){ checkout scm def image_name = getDockerImageName() echo "Building Docker for ${image_name}" - def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - echo "ccache server: ${env.CK_CCACHE}" - if(env.CK_CCACHE) - { - if(check_host()) - { - echo "FOUND CCACHE SERVER: ${env.CK_CCACHE}" - } - else - { - echo "CCACHE SERVER: ${env.CK_CCACHE} NOT FOUND, got ${check_host} response" - } - dockerArgs = dockerArgs + " --build-arg CCACHE_SECONDARY_STORAGE='redis://${env.CK_CCACHE}' --build-arg COMPILER_LAUNCHER='ccache' " - env.CCACHE_DIR = """/tmp/ccache_store""" - env.CCACHE_SECONDARY_STORAGE="""redis://${env.CK_CCACHE}""" + def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg 
compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' " + if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + dockerArgs = dockerArgs + " --no-cache " } - echo "Build Args: ${dockerArgs}" try{ if(params.BUILD_DOCKER){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" retimage = docker.build("${image_name}", dockerArgs + ' .') - retimage.push() + withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage.push() + } + sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi' } else{ echo "Checking for image: ${image_name}" sh "docker manifest inspect --insecure ${image_name}" - echo "Image: ${image_name} found!! Skipping building image" + echo "Image: ${image_name} found! Skipping building image" } } catch(Exception ex){ echo "Unable to locate image: ${image_name}. Building image now" retimage = docker.build("${image_name}", dockerArgs + ' .') - retimage.push() + withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage.push() + } } } @@ -210,19 +193,18 @@ def cmake_build(Map conf=[:]){ } else{ setup_args = ' -DBUILD_DEV=On' + setup_args } + if (params.DL_KERNELS){ + setup_args = setup_args + " -DDL_KERNELS=ON " + } if(build_type_debug){ setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args }else{ setup_args = " -DCMAKE_BUILD_TYPE=release" + setup_args } - if(env.CK_CCACHE) - { - setup_args = " -DCMAKE_CXX_COMPILER_LAUNCHER='ccache' -DCMAKE_C_COMPILER_LAUNCHER='ccache' " + setup_args - } - echo "ccache server: ${env.CK_CCACHE}" def pre_setup_cmd = """ + #!/bin/bash echo \$HSA_ENABLE_SDMA ulimit -c unlimited rm -rf build @@ -231,26 +213,126 @@ def cmake_build(Map conf=[:]){ mkdir install cd build """ - def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") + def invocation_tag="" + if (setup_args.contains("gfx12")){ + invocation_tag="gfx12" + } + if (setup_args.contains("gfx11")){ + invocation_tag="gfx11" + } + if (setup_args.contains("gfx10")){ + invocation_tag="gfx10" + } + if (setup_args.contains("gfx90")){ + invocation_tag="gfx90" + } + if (setup_args.contains("gfx94")){ + invocation_tag="gfx94" + } + echo "invocation tag: ${invocation_tag}" + def redis_pre_setup_cmd = pre_setup_cmd + if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") { + redis_pre_setup_cmd = pre_setup_cmd + """ + #!/bin/bash + export ROCM_PATH=/opt/rocm + export SCCACHE_ENABLED=true + export SCCACHE_LOG_LEVEL=debug + export SCCACHE_IDLE_TIMEOUT=14400 + export COMPILERS_HASH_DIR=/tmp/.sccache + export SCCACHE_BIN=/usr/local/.cargo/bin/sccache + export SCCACHE_EXTRAFILES=/tmp/.sccache/rocm_compilers_hash_file + export SCCACHE_REDIS="redis://${env.CK_SCCACHE}" + echo "connect = ${env.CK_SCCACHE}" >> ../script/redis-cli.conf + export SCCACHE_C_CUSTOM_CACHE_BUSTER="${invocation_tag}" + echo \$SCCACHE_C_CUSTOM_CACHE_BUSTER + stunnel ../script/redis-cli.conf + ../script/sccache_wrapper.sh --enforce_redis + """ + try { + def cmd1 = conf.get("cmd1", """ + ${redis_pre_setup_cmd} + """) + sh cmd1 + setup_args = " -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache " + setup_args + } + catch(Exception err){ + echo "could not connect to redis server: ${err.getMessage()}. will not use sccache." 
+ def cmd2 = conf.get("cmd2", """ + ${pre_setup_cmd} + """) + sh cmd2 + } + } + else{ + def cmd3 = conf.get("cmd3", """ + ${pre_setup_cmd} + """) + sh cmd3 + } + // reduce parallelism when compiling, clang uses too much memory def nt = nthreads() - def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}") + def cmd + def setup_cmd + def build_cmd def execute_cmd = conf.get("execute_cmd", "") - - def cmd = conf.get("cmd", """ - ${pre_setup_cmd} + if(!setup_args.contains("NO_CK_BUILD")){ + if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ + echo "running ninja build trace" + setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake -G Ninja ${setup_args} .. ") + build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") + } + else{ + setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") + build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}") + } + cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} ${execute_cmd} """) + } + else{ + cmd = conf.get("cmd", """ + ${execute_cmd} + """) + } echo cmd - sh cmd + + dir("build"){ + //build CK + sh cmd + //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set + if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ + if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ + sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json" + archiveArtifacts "ck_build_trace.json" + sh "ninja test" + } + else{ + sh "make check" + } + } + } // Only archive from master or develop - if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "master")) { + if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } + if (params.RUN_CK_TILE_FMHA_TESTS){ + try{ + archiveArtifacts "perf_fmha_fwd_*.log" + archiveArtifacts "perf_fmha_bwd_*.log" + stash name: "perf_fmha_fwd_gfx942.log" + stash name: "perf_fmha_bwd_gfx942.log" + stash name: "perf_fmha_fwd_gfx90a.log" + stash name: "perf_fmha_bwd_gfx90a.log" + } + catch(Exception err){ + echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." 
+ } + } } def buildHipClangJob(Map conf=[:]){ @@ -259,7 +341,15 @@ def buildHipClangJob(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm - def image = getDockerImageName() + def image + if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){ + image = conf.get("docker_name", "") + echo "Using legacy docker: ${image}" + } + else{ + image = getDockerImageName() + echo "Using default docker: ${image}" + } def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group @@ -267,19 +357,23 @@ def buildHipClangJob(Map conf=[:]){ if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } - def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" def variant = env.STAGE_NAME def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 5, unit: 'HOURS') + timeout(time: 48, unit: 'HOURS') { cmake_build(conf) } @@ -322,21 +416,23 @@ def runCKProfiler(Map conf=[:]){ if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " - } def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: 
dockerOpts) { timeout(time: 5, unit: 'MINUTES'){ - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log' - if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ + sh 'rocminfo | tee rocminfo.log' + if ( !runShell('grep -n "gfx" rocminfo.log') ){ throw new Exception ("GPU not found") } else{ @@ -349,26 +445,10 @@ def runCKProfiler(Map conf=[:]){ echo "The job was cancelled or aborted" throw e } - catch(Exception ex) { - retimage = docker.build("${image}", dockerArgs + " --no-cache .") - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES'){ - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log' - if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ - throw new Exception ("GPU not found") - } - else{ - echo "GPU is OK" - } - } - } - } withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { - //cmake_build(conf) - //instead of building, just unstash the ckProfiler and install it sh """ rm -rf build mkdir build @@ -380,31 +460,34 @@ def runCKProfiler(Map conf=[:]){ dir("script"){ if (params.RUN_FULL_QA){ - sh "./run_full_performance_tests.sh 1 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" + sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" archiveArtifacts "perf_gemm.log" archiveArtifacts "perf_resnet50_N256.log" archiveArtifacts "perf_resnet50_N4.log" archiveArtifacts "perf_batched_gemm.log" archiveArtifacts "perf_grouped_gemm.log" - archiveArtifacts "perf_conv_fwd.log" - archiveArtifacts "perf_conv_bwd_data.log" + archiveArtifacts "perf_grouped_conv_fwd.log" + archiveArtifacts "perf_grouped_conv_bwd_data.log" + archiveArtifacts "perf_grouped_conv_bwd_weight.log" archiveArtifacts "perf_gemm_bilinear.log" archiveArtifacts "perf_reduction.log" - archiveArtifacts "perf_splitK_gemm_verify.log" archiveArtifacts "perf_splitK_gemm.log" archiveArtifacts "perf_onnx_gemm.log" + archiveArtifacts "perf_mixed_gemm.log" // stash perf files to master stash name: "perf_gemm.log" stash name: "perf_resnet50_N256.log" stash name: "perf_resnet50_N4.log" stash name: "perf_batched_gemm.log" stash name: "perf_grouped_gemm.log" - stash name: "perf_conv_fwd.log" - stash name: "perf_conv_bwd_data.log" + stash name: "perf_grouped_conv_fwd.log" + stash name: "perf_grouped_conv_bwd_data.log" + stash name: "perf_grouped_conv_bwd_weight.log" stash name: "perf_gemm_bilinear.log" stash name: "perf_reduction.log" stash name: "perf_splitK_gemm.log" stash name: "perf_onnx_gemm.log" + stash name: "perf_mixed_gemm.log" //we will process results on the master node } else{ @@ -445,9 +528,19 @@ def Build_CK(Map conf=[:]){ show_node_info() env.HSA_ENABLE_SDMA=0 + env.DOCKER_BUILDKIT=1 checkout scm - def image = getDockerImageName() + def image + if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){ + image = conf.get("docker_name", "") + echo "Using legacy docker: ${image}" + } + else{ + image = getDockerImageName() + echo "Using default docker: ${image}" + } + def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group @@ -456,29 +549,32 @@ def Build_CK(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg 
ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } + if(params.BUILD_LEGACY_OS){ + dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' " + } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" def variant = env.STAGE_NAME def retimage - def navi_node = 0 - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { timeout(time: 5, unit: 'MINUTES'){ - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log' - if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ + sh 'rocminfo | tee rocminfo.log' + if ( !runShell('grep -n "gfx" rocminfo.log') ){ throw new Exception ("GPU not found") } else{ echo "GPU is OK" } - if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){ - navi_node = 1 - } } } } @@ -486,43 +582,53 @@ def Build_CK(Map conf=[:]){ echo "The job was cancelled or aborted" throw e } - catch(Exception ex) { - retimage = docker.build("${image}", dockerArgs + " --no-cache .") - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES'){ - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo |tee clinfo.log' - if ( runShell('grep -n "Number of devices:.*. 
0" clinfo.log') ){ - throw new Exception ("GPU not found") - } - else{ - echo "GPU is OK" - } - if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){ - navi_node = 1 - } - } - } - } withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { + //check whether to run performance tests on this node + def do_perf_tests = 0 + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') || runShell('grep -n "gfx1201" rocminfo.log') || runShell('grep -n "gfx942" rocminfo.log') ){ + do_perf_tests = 1 + echo "Stash profiler and run performance tests" + } cmake_build(conf) dir("build"){ //run tests and examples - sh 'make -j check' - if (navi_node == 0 ){ + //sh 'make -j check' + if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){ //we only need the ckProfiler to run the performance tests, so we pack and stash it - //do not stash profiler on Navi nodes - sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' - stash "ckProfiler.tar.gz" + //do not stash profiler on nodes where we don't need to run performance tests + sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' + stash name: "ckProfiler.tar.gz" } - if (params.RUN_FULL_QA){ - // build deb packages - sh 'make -j package' - archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' - archiveArtifacts artifacts: 'composablekernel-tests_*.deb' - sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash "ckprofiler_0.2.0_amd64.deb" + if (params.RUN_FULL_QA && do_perf_tests == 0 ){ + // build deb packages for all gfx9 targets and prepare to export + sh 'make -j package' + archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' + archiveArtifacts artifacts: 'composablekernel-tests_*.deb' + sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' + stash name: "ckprofiler_0.2.0_amd64.deb" + } + } + if (params.hipTensor_test && do_perf_tests == 0 ){ + //build and test hipTensor + sh """#!/bin/bash + rm -rf "${params.hipTensor_branch}".zip + rm -rf hipTensor-"${params.hipTensor_branch}" + wget https://github.com/ROCm/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip + unzip -o "${params.hipTensor_branch}".zip + """ + dir("hipTensor-${params.hipTensor_branch}"){ + sh """#!/bin/bash + mkdir -p build + ls -ltr + CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install" + cmake --build build -- -j + """ + } + dir("hipTensor-${params.hipTensor_branch}/build"){ + sh 'ctest' } } } @@ -562,7 +668,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) } @@ -576,22 +682,35 @@ def process_results(Map conf=[:]){ timeout(time: 1, unit: 'HOURS'){ try{ dir("script"){ + if (params.RUN_CK_TILE_FMHA_TESTS){ + try{ + unstash "perf_fmha_fwd_gfx942.log" + unstash "perf_fmha_bwd_gfx942.log" + unstash "perf_fmha_fwd_gfx90a.log" + unstash "perf_fmha_bwd_gfx90a.log" + } + catch(Exception err){ + echo "could not locate the FMHA performance logs: ${err.getMessage()}." 
+ } + } if (params.RUN_FULL_QA){ // unstash perf files to master + unstash "ckprofiler_0.2.0_amd64.deb" + sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" unstash "perf_gemm.log" unstash "perf_resnet50_N256.log" unstash "perf_resnet50_N4.log" unstash "perf_batched_gemm.log" unstash "perf_grouped_gemm.log" - unstash "perf_conv_fwd.log" - unstash "perf_conv_bwd_data.log" + unstash "perf_grouped_conv_fwd.log" + unstash "perf_grouped_conv_bwd_data.log" + unstash "perf_grouped_conv_bwd_weight.log" unstash "perf_gemm_bilinear.log" unstash "perf_reduction.log" unstash "perf_splitK_gemm.log" unstash "perf_onnx_gemm.log" + unstash "perf_mixed_gemm.log" sh "./process_qa_data.sh" - unstash "ckprofiler_0.2.0_amd64.deb" - sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" } else{ // unstash perf files to master @@ -603,18 +722,24 @@ def process_results(Map conf=[:]){ } } catch(e){ - echo "throwing error exception while processing performance test results" + echo "Throwing error exception while processing performance test results" echo 'Exception occurred: ' + e.toString() throw e } + finally{ + echo "Finished processing performance test results" + } } } } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1 - 0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT= - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" +CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true + 0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true;RUN_CODEGEN_TESTS=true + 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true + 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true + 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false + 0 13 * * * % BUILD_LEGACY_OS=true''' : "" pipeline { agent none @@ -629,26 +754,86 @@ pipeline { name: "BUILD_DOCKER", defaultValue: false, description: "Force building docker image (default: false), set to true if docker image needs to be updated.") + string( + name: 'USE_CUSTOM_DOCKER', + defaultValue: '', + description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '5.6', - description: 'Specify which ROCM version to use: 5.6 (default).') + defaultValue: '6.2', + description: 'Specify which ROCM version to use: 6.2 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', - description: 'Specify which version of compiler to use: release, amd-stg-open, or leave blank (default).') + description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline-open, or leave blank (default).') string( name: 'COMPILER_COMMIT', defaultValue: '', - description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.') + description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit (default), or use some specific commit of llvm-project branch.') string( name: 'BUILD_COMPILER', - defaultValue: 'hipcc', - description: 'Specify whether to build CK with hipcc (default) or with clang.') + defaultValue: '/opt/rocm/llvm/bin/clang++', + description: 'Build CK with /opt/rocm/bin/hipcc, /llvm-project/build/bin/clang++, or with /opt/rocm/llvm/bin/clang++ (default).') booleanParam( name: "RUN_FULL_QA", defaultValue: false, description: "Select whether to run small set of performance tests (default) or full QA") + booleanParam( + name: "DL_KERNELS", + defaultValue: false, + description: "Select whether to build DL kernels (default: OFF)") + booleanParam( + name: "hipTensor_test", + defaultValue: false, + description: "Use the CK build to verify hipTensor build and tests (default: OFF)") + string( + name: 'hipTensor_branch', + defaultValue: 'mainline', + description: 'Specify which branch of hipTensor to use (default: mainline)') + booleanParam( + name: "USE_SCCACHE", + defaultValue: true, + description: "Use the sccache for building CK (default: ON)") + booleanParam( + name: "RUN_CPPCHECK", + defaultValue: false, + description: "Run the cppcheck static analysis (default: OFF)") + booleanParam( + name: "RUN_PERFORMANCE_TESTS", + defaultValue: true, + description: "Run the performance tests (default: ON)") + booleanParam( + name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS", + defaultValue: false, + description: "Run the grouped conv large cases tests (default: OFF)") + booleanParam( + name: "RUN_CODEGEN_TESTS", + defaultValue: false, + description: "Run codegen tests (default: OFF)") + booleanParam( + name: 
"RUN_CK_TILE_FMHA_TESTS", + defaultValue: false, + description: "Run the ck_tile FMHA tests (default: OFF)") + booleanParam( + name: "RUN_CK_TILE_GEMM_TESTS", + defaultValue: false, + description: "Run the ck_tile GEMM tests (default: OFF)") + booleanParam( + name: "BUILD_INSTANCES_ONLY", + defaultValue: false, + description: "Test building instances for various architectures simultaneously (default: OFF)") + booleanParam( + name: "BUILD_GFX12", + defaultValue: false, + description: "Build CK and run tests on gfx12 (default: OFF)") + booleanParam( + name: "NINJA_BUILD_TRACE", + defaultValue: false, + description: "Generate a ninja build trace (default: OFF)") + booleanParam( + name: "BUILD_LEGACY_OS", + defaultValue: false, + description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -657,30 +842,61 @@ pipeline { dbsshport = "${dbsshport}" dbsshuser = "${dbsshuser}" dbsshpassword = "${dbsshpassword}" - status_wrapper_creds = "${status_wrapper_creds}" + ck_git_creds = "${ck_git_creds}" gerrit_cred="${gerrit_cred}" DOCKER_BUILDKIT = "1" } stages{ stage("Build Docker"){ - //when { - // beforeAgent true - // expression { params.BUILD_DOCKER.toBoolean() } - //} parallel{ stage('Docker /opt/rocm'){ agent{ label rocmnode("nogpu") } steps{ buildDocker('/opt/rocm') + cleanWs() } } } } stage("Static checks") { parallel{ + stage('Clang Format and Cppcheck') { + when { + beforeAgent true + expression { params.RUN_CPPCHECK.toBoolean() } + } + agent{ label rocmnode("nogpu") } + environment{ + setup_args = "NO_CK_BUILD" + execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ + -o -not -path \'*.git*\' -iname \'*.hpp\' \ + -o -not -path \'*.git*\' -iname \'*.cpp\' \ + -o -iname \'*.h.in\' \ + -o -iname \'*.hpp.in\' \ + -o -iname \'*.cpp.in\' \ + -o -iname \'*.cl\' \ + | grep -v 'build/' \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ + /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ + -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 -D DL_KERNELS \ + -D __gfx908__ -D __gfx90a__ -D __gfx940__ -D __gfx941__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ + -U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \ + --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + archiveArtifacts "build/ck_cppcheck.log" + cleanWs() + } + } stage('Clang Format') { + when { + beforeAgent true + expression { !params.RUN_CPPCHECK.toBoolean() } + } agent{ label rocmnode("nogpu") } environment{ + setup_args = "NO_CK_BUILD" execute_cmd = "find .. 
-not -path \'*.git*\' -iname \'*.h\' \ -o -not -path \'*.git*\' -iname \'*.hpp\' \ -o -not -path \'*.git*\' -iname \'*.cpp\' \ @@ -692,113 +908,352 @@ pipeline { | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'" } steps{ - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + cleanWs() } } } } - - stage("Build CK and run Tests") + stage("Run Grouped Conv Large Case Tests") { parallel { - stage("Build CK and run Tests on MI100/MI200/MI300") + stage("Run Grouped Conv Large Case Tests on gfx90a") { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() } + expression { params.RUN_GROUPED_CONV_LARGE_CASES_TESTS.toBoolean() } } - agent{ label rocmnode("gfx908 || gfx90a") } + agent{ label rocmnode("gfx90a")} environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 test_grouped_convnd_fwd_large_cases_xdl && \ + ./bin/test_grouped_convnd_fwd_large_cases_xdl""" } steps{ - Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() } } - stage("Build CK and run Tests on MI100/MI200") + } + } + stage("Run Codegen Tests") + { + parallel + { + stage("Run Codegen Tests on gfx90a") { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { params.RUN_CODEGEN_TESTS.toBoolean() } } - agent{ label rocmnode("gfx908 || gfx90a") } + agent{ label rocmnode("gfx90a")} environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. 
&& make -j """ + setup_args = "NO_CK_BUILD" + execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \ + make -j64 check""" } steps{ - Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() } } - stage("Build CK and run Tests on Navi21") + } + } + stage("Run CK_TILE_FMHA Tests") + { + parallel + { + stage("Run CK_TILE_FMHA Tests on gfx90a") { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() } } - agent{ label rocmnode("navi21") } + agent{ label rocmnode("gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ - + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ + cd ../ && + example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() } } - stage("Build CK and run Tests on Navi32") + stage("Run CK_TILE_FMHA Tests on gfx942") { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() } } - agent{ label rocmnode("navi32") } + agent{ label rocmnode("gfx942") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. 
&& make -j """ - + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ + cd ../ && + example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } + stage("Run CK_TILE_GEMM Tests") + { + parallel + { + stage("Run CK_TILE_GEMM Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 tile_example_gemm_basic && \ + cd ../ && + example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run CK_TILE_GEMM Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make -j64 tile_example_gemm_basic && \ + cd ../ && + example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() } } } } - stage("Performance Tests") + stage("Build CK and run Tests") { parallel { - stage("Run ckProfiler: gfx90*") + stage("Build CK with RHEL8") { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { params.BUILD_LEGACY_OS.toBoolean() } } - options { retry(2) } - agent{ label rocmnode("gfx908 || gfx90a")} + agent{ label rocmnode("gfx90a") } environment{ - setup_args = """ -DGPU_TARGETS="gfx908;gfx90a" -DBUILD_DEV=On """ - } + def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3" + setup_args = """ -DGPU_TARGETS="gfx942" \ + -DCMAKE_CXX_FLAGS=" -O3 " \ + -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ + execute_args = " " + } steps{ - runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name) + cleanWs() + } + } + stage("Build CK with SLES15") + { + when { + beforeAgent true + expression { params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_sles15_rocm6.3" + setup_args = """ -DGPU_TARGETS="gfx942" \ + -DCMAKE_CXX_FLAGS=" -O3 " \ + -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ + execute_args = " " + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name) + cleanWs() + } + } + stage("Build CK for all gfx9 targets") + { + when { + beforeAgent true + expression { params.RUN_FULL_QA.toBoolean() 
&& !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ + -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ + -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } + stage("Build CK and run Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx942" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } + stage("Build CK and run Tests on gfx90a") + { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx90a" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } + stage("Build CK instances for different targets") + { + when { + beforeAgent true + expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ + -D CMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j64 """ + } + steps{ + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Build CK and run Tests on gfx1030") + { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx1030") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx1030" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } + stage("Build CK and run Tests on gfx1101") + { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx1101") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx1101" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } + stage("Build CK and run Tests on gfx1201") + { + when { + beforeAgent true + expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx1201") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1201" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx1201" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } + } + } + + stage("Performance Tests") + { + parallel + { stage("Run ckProfiler: gfx90a") { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() } + expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - options { retry(2) } + options { retry(1) } agent{ label rocmnode("gfx90a")} environment{ - setup_args = """ -DGPU_TARGETS="gfx90a" -DBUILD_DEV=On """ + setup_args = "NO_CK_BUILD" } steps{ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + cleanWs() } } } @@ -808,9 +1263,14 @@ pipeline { parallel { stage("Process results"){ + when { + beforeAgent true + expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } agent { label 'mici' } steps{ process_results() + cleanWs() } } } diff --git a/LICENSE b/LICENSE index e03fddaf78080705d26ec277629cfb8010c077bf..581b5efde535f686aa0a584709fccaa92353d125 100644 --- a/LICENSE +++ b/LICENSE @@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) SPDX-License-Identifier: MIT -Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index c2b493db11596d489c6bc5e4fcf9549080998c01..302173dc1729fde4c1f81277193031a291796149 100644 --- a/README.md +++ b/README.md @@ -1,139 +1,197 @@ # Composable Kernel -## Methodology +> [!NOTE] +> The published documentation is available at [Composable Kernel](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/) in an organized, easy-to-read format, with search and a table of contents. The documentation source files reside in the `docs` folder of this repository. As with all ROCm projects, the documentation is open source. For more information on contributing to the documentation, see [Contribute to ROCm documentation](https://rocm.docs.amd.com/en/latest/contribute/contributing.html). -Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. +The Composable Kernel (CK) library provides a programming model for writing performance-critical +kernels for machine learning workloads across multiple architectures (GPUs, CPUs, etc.). The CK library +uses general purpose kernel languages, such as HIP C++. + +CK uses two concepts to achieve performance portability and code maintainability: -CK utilizes two concepts to achieve performance portability and code maintainability: * A tile-based programming model -* Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". +* Algorithm complexity reduction for complex machine learning (ML) operators. This uses an innovative + technique called *Tensor Coordinate Transformation*. 
![ALT](/docs/data/ck_component.png "CK Components") -## Code Structure +The current CK library is structured into four layers: -Current CK library are structured into 4 layers: -* "Templated Tile Operators" layer -* "Templated Kernel and Invoker" layer -* "Instantiated Kernel and Invoker" layer -* "Client API" layer +* Templated Tile Operators +* Templated Kernel and Invoker +* Instantiated Kernel and Invoker +* Client API ![ALT](/docs/data/ck_layer.png "CK Layers") -## Documentation +## General information -Run the steps below to build documentation locally. +To build our documentation locally, use the following code: -``` +``` bash cd docs pip3 install -r sphinx/requirements.txt python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html ``` -## Contributors +You can find a list of our developers and contributors on our [Contributors](/CONTRIBUTORS.md) page. -The list of developers and contributors is here: [Contributors](/CONTRIBUTORS.md) +```note +If you use CK, cite us as follows: -## Citation - -If you use CK, please use following citations: -* CK paper will be freely available on arXiv soon: [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???) +* [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???): + This paper will be available on arXiv soon. * [CITATION.cff](/CITATION.cff) +``` -## License +CK is released under the **[MIT license](/LICENSE)**. -CK is released under the MIT license. [License File](/LICENSE) +## Building CK +We recommend building CK inside Docker containers, which include all necessary packages. Pre-built +Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composable_kernel/tags). -# Build CK +1. To build a new Docker image, use the Dockerfile provided with the source code: -## Build docker image + ```bash + DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . + ``` -```bash -DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . -``` -Pre-built dockers are available from this public repo: -https://hub.docker.com/r/rocm/composable_kernel/tags +2. Launch the Docker container: -## Launch docker + ```bash + docker run \ + -it \ + --privileged \ + --group-add sudo \ + -w /root/workspace \ + -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ + ck:latest \ + /bin/bash + ``` -```bash -docker run \ --it \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -ck:latest \ -/bin/bash -``` +3. Clone CK source code from the GitHub repository and start the build: -## Build CK + ```bash + git clone https://github.com/ROCm/composable_kernel.git && \ + cd composable_kernel && \ + mkdir build && \ + cd build + ``` -```bash -mkdir build && cd build + You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s) you want + to run CK on. You can specify single or multiple architectures. If you specify multiple architectures, + use a semicolon between each; for example, `gfx908;gfx90a;gfx940`. -# Need to specify target ID, example below is for gfx908 and gfx90a + ```bash + cmake \ + -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx908;gfx90a" \ + .. + ``` -cmake \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_BUILD_TYPE=Release \ --D GPU_TARGETS="gfx908;gfx90a" \ -.. 
-``` + If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets + supported by the current compiler (this may take a long time). + Tests and examples will only get built if the GPU_TARGETS is set by the user on the cmake command line. -If GPU_TARGETS is not set on the cmake command line, CK will be built for all targets supported by the -current compiler. + NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the + architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you + want to build the library for a list of different architectures, + you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`. +4. Build the entire CK library: -Additional cmake flags can be used to significantly speed-up the build: + ```bash + make -j + ``` -INSTANCES_ONLY (by default is OFF) must be set to ON in order to build only the instances and library -while skipping all tests, examples, and profiler. This is useful for libraries that use CK as a dependency. +5. Install CK: -DTYPES (by default not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instances -of select data types only. Currently, building of int8 instances is taking a lot of time (the compiler fix is in the works). + ```bash + make -j install + ``` -DL_KERNELS (by default is OFF) must be set to ON in order to build the gemm_dl and batched_gemm_multi_d_dl -instances. Those instances are only needed for the NAVI2x platforms. +## Optional post-install steps -### Build examples and tests +* Build examples and tests: -```bash - make -j examples tests - make test -``` + ```bash + make -j examples tests + ``` + +* Build and run all examples and tests: -Instructions for running each individual examples are under [example](/example) + ```bash + make -j check + ``` + You can find instructions for running each individual example in [example](/example). -## Build ckProfiler +* Build ckProfiler: + + ```bash + make -j ckProfiler + ``` + + You can find instructions for running ckProfiler in [profiler](/profiler). + +Note the `-j` option for building with multiple threads in parallel, which speeds up the build significantly. +However, `-j` launches unlimited number of threads, which can cause the build to run out of memory and +crash. On average, you should expect each thread to use ~2Gb of RAM. +Depending on the number of CPU cores and the amount of RAM on your system, you may want to +limit the number of threads. For example, if you have a 128-core CPU and 128 Gb of RAM it's advisable to use `-j32`. + +Additional cmake flags can be used to significantly speed-up the build: + +* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build + instances of select data types only. The main default data types are fp32 and fp16; you can safely skip + other data types. + +* `DL_KERNELS` (default is OFF) must be set to ON in order to build instances, such as `gemm_dl` or + `batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most + other platforms have faster instances, such as `xdl` or `wmma`, available. + +* `CK_USE_FP8_ON_UNSUPPORTED_ARCH` (default is OFF) must be set to ON in order to build instances, + such as `gemm_universal` and `gemm_multiply_multiply` for fp8 data type for GPU targets which do not + have native support for fp8 data type, such as gfx908 or gfx90a. 
These instances are useful on + architectures like the MI100/MI200 for the functional support only. + +## Using sccache for building + +The default CK Docker images come with a pre-installed version of sccache, which supports clang +being used as hip-compiler (" -x hip"). Using sccache can help reduce the time to re-build code from +hours to 1-2 minutes. In order to invoke sccache, you need to run: ```bash - make -j ckProfiler + sccache --start-server ``` -Instructions for running ckProfiler are under [profiler](/profiler) -## Install CK +then add the following flags to the cmake command line: ```bash -make install + -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache ``` +You may need to clean up the build folder and repeat the cmake and make steps in order to take +advantage of the sccache during subsequent builds. + ## Using CK as pre-built kernel library -Instructions for using CK as a pre-built kernel library are under [client_example](/client_example) +You can find instructions for using CK as a pre-built kernel library in [client_example](/client_example). -## Contributing +## Contributing to CK -When you contribute to Composable Kernel, make sure to run `clang-format` on all the changed files. We highly recommend using git hooks that are managed by the `pre-commit` framework. To install hooks, run: +When you contribute to CK, make sure you run `clang-format` on all changed files. We highly +recommend using git hooks that are managed by the `pre-commit` framework. To install hooks, run: ```bash sudo script/install_precommit.sh ``` -This way, `pre-commit` will add the appropriate hooks to your local repository and automatically run `clang-format` (and possibly additional checks) before any commit is created. +With this approach, `pre-commit` adds the appropriate hooks to your local repository and +automatically runs `clang-format` (and possibly additional checks) before any commit is created. If you need to uninstall hooks from the repository, you can do so by running the following command: @@ -141,14 +199,5 @@ If you need to uninstall hooks from the repository, you can do so by running the script/uninstall_precommit.sh ``` -If for any reason, you need to temporarily disable precommit hooks, you can add the `--no-verify` option to the `git commit` command. - -## Caveat -### Kernel Timing and Verification - -CK's own kernel timer will warn up kernel once, and then run it multiple times -to get average kernel time. For some kernels that use atomic add, this will cause -output buffer to be accumulated multiple times, causing verification failure. -To work around it, do not use CK's own timer and do verification at the same time. -CK's own timer and verification in each example and ckProfiler can be enabled or -disabled from command line. +If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the +`git commit` command. 
diff --git a/client_example/01_gemm/CMakeLists.txt b/client_example/01_gemm/CMakeLists.txt index 9e741192f90b8216e4b3abe32ae8971fb45ddfee..6c4103cda879128992bfe2b0ff4757f8f8b95a9c 100644 --- a/client_example/01_gemm/CMakeLists.txt +++ b/client_example/01_gemm/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_gemm gemm.cpp) -target_link_libraries(client_gemm PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp index c37f208db1cee9bfff1fff469fd79d059fb179f0..e63cda616206daf7d874d2e59ddb4d294dbc5f0a 100644 --- a/client_example/01_gemm/gemm.cpp +++ b/client_example/01_gemm/gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -83,7 +83,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } @@ -185,6 +185,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index ba2952022233b3ffd21fc245b8fd199a5ec5b096..4ba86026b24f48a07f6714ec9d75dc0670a40267 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,27 +1,29 @@ -add_custom_target(client_gemm_fastgelu_examples) +if(GPU_TARGETS MATCHES "gfx9") + add_custom_target(client_gemm_fastgelu_examples) -add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations) + add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_operations) + add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) -target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_operations) + add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) + target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu + add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu client_gemm_fastgelu) -add_custom_target(client_gemm_fastgelu_generic_examples) + add_custom_target(client_gemm_fastgelu_generic_examples) -add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_add_fastgelu_generic PRIVATE composable_kernel::device_operations) + add_executable(client_gemm_add_add_fastgelu_generic 
gemm_add_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_operations) + add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) -target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_operations) + add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) + target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic + add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) +endif() diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp index 756889562e84c66efb5f972621bcb61edda3af82..5809681661b91bf329a9df5a699c96531b408c7d 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -92,7 +92,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } @@ -204,6 +204,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp index 2ed942f0adf64df90ea1046d1de191ce2247d7c4..3cc4313aabb9b2a78960bd2f71775200fe57f538 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -93,7 +93,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp index 8d2a8c234aae63a3566478b5aa9588389a247d4e..1fd80d10c736ab87d7d9e064cb303f65962cb62e 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -88,7 +88,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } @@ -197,6 +197,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp index 644b428fc9f51b28a0f81ed7d79ac741ca6fbcbe..e54bcfd989078523e41458e3c7b172838f72f59b 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -89,7 +89,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp index c02df018fd35c6d37a2c6f9b9fded6390c9afb19..47fd58f691020decb602f320fc18e281dba82295 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -84,7 +84,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } @@ -190,6 +190,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp index 482e93b421f7700cdf37ee014cf407ca3c63555b..f43554f2bde85811ea46a6ca539834d6940608a9 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -85,7 +85,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt index b38698d9064e6fa3cbd2e01c8818c0ebfe361cb3..8fedc846358583f9ff073921c556199ace0d95c8 100644 --- a/client_example/03_gemm_layernorm/CMakeLists.txt +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) -target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) + target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) -add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) -target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_operations) + add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) + target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) +endif() diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp index 58c91f903bc7e2f3f296b06c746b1629513bc7f2..020f047d1af7babfe5862a7106fdd7f48180f910 100644 --- a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
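The recurring "+ if(found)" additions in these hunks guard the final "run the best instance" block: when no instance supports the given problem, best_op_id stays at -1 and indexing op_ptrs with it would be out of bounds. A stripped-down sketch of the selection-then-guard pattern (the names and timing values are illustrative, not the CK API):

    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<float> times_ms = {/* one entry per supported instance */};

        bool  found      = false;
        int   best_op_id = -1;
        float best_time  = 1e30f;

        for(int i = 0; i < static_cast<int>(times_ms.size()); ++i)
        {
            if(times_ms[i] < best_time) // only reached for supported instances
            {
                found      = true;
                best_op_id = i;
                best_time  = times_ms[i];
            }
        }

        if(found) // without this guard, best_op_id == -1 would index out of bounds
        {
            std::printf("best instance: %d (%f ms)\n", best_op_id, best_time);
        }
        return 0;
    }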
#include #include @@ -8,7 +8,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp" @@ -17,6 +17,8 @@ using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; + using ADataType = F16; using BDataType = F16; using BiasDataType = F32; @@ -191,7 +193,7 @@ int main() [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp index 3d5fb6004844af269f7786ae05de8c20cc720633..7d5ef5f9bfe909f4a37a3fcb27756c2aaebb5df1 100644 --- a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -78,7 +78,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } @@ -200,6 +200,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt index 7ffedfeef36839c1c8bb942ffae11e5f6dcadf22..13c0375846064110e1ee532f035cda950071cc8d 100644 --- a/client_example/04_contraction/CMakeLists.txt +++ b/client_example/04_contraction/CMakeLists.txt @@ -1,15 +1,16 @@ -add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) -target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) + target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) -target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_operations) + add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) + target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) -target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_operations) + add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) + 
target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) -target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_operations) - -add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) -target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_operations) + add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) + target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) + target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/04_contraction/contraction_bilinear_fp32.cpp b/client_example/04_contraction/contraction_bilinear_fp32.cpp index 89f834b9824e134f8f0aeed8aa54f78a5c8824a3..f1881e60a07f20f2ea774024b691d1736f7a83d5 100644 --- a/client_example/04_contraction/contraction_bilinear_fp32.cpp +++ b/client_example/04_contraction/contraction_bilinear_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_bilinear_fp64.cpp b/client_example/04_contraction/contraction_bilinear_fp64.cpp index 1aa3ba7de597a0b97a295b0f7ee7ad21b1e9cd80..8b499eee215579df3db07bf714300ff8ec48184c 100644 --- a/client_example/04_contraction/contraction_bilinear_fp64.cpp +++ b/client_example/04_contraction/contraction_bilinear_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp index f8ea2258c2ba262e9db42f1b4dc92ff16cdc6286..a5ef40a2dc22a7df2f3c6b60236a5640a5003180 100644 --- a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp +++ b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_scale_fp32.cpp b/client_example/04_contraction/contraction_scale_fp32.cpp index ba7b0633c33aabeeb06547edfb08f506e637e599..5c06d31488d82073b555dcb7d549a7bc01ffdcf2 100644 --- a/client_example/04_contraction/contraction_scale_fp32.cpp +++ b/client_example/04_contraction/contraction_scale_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include diff --git a/client_example/04_contraction/contraction_scale_fp64.cpp b/client_example/04_contraction/contraction_scale_fp64.cpp index 24e52eb5aa423339ff96ad0914dc479d715fe7b7..14fb8741e75dbe79e46ca2d9c7d8129621fb8166 100644 --- a/client_example/04_contraction/contraction_scale_fp64.cpp +++ b/client_example/04_contraction/contraction_scale_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/05_layernorm/CMakeLists.txt b/client_example/05_layernorm/CMakeLists.txt index b582b485d4ce46951aaed98c6256e1c997388bd3..b7b3c830ed5beaab6bb508d99799c998dd80e340 100644 --- a/client_example/05_layernorm/CMakeLists.txt +++ b/client_example/05_layernorm/CMakeLists.txt @@ -1,2 +1,11 @@ -add_executable(client_layernorm2d layernorm2d.cpp) -target_link_libraries(client_layernorm2d PRIVATE composable_kernel::device_operations) +add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp) +target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations) + +add_executable(client_layernorm2d_bwd_gamma_beta layernorm2d_bwd_gamma_beta.cpp) +target_link_libraries(client_layernorm2d_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations) + +add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp) +target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations) + +add_executable(client_layernorm4d_fwd layernorm4d_fwd.cpp) +target_link_libraries(client_layernorm4d_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/05_layernorm/layernorm2d.cpp b/client_example/05_layernorm/layernorm2d.cpp deleted file mode 100644 index 4af4d7abe8ee6c5120f6060c78141021052cb619..0000000000000000000000000000000000000000 --- a/client_example/05_layernorm/layernorm2d.cpp +++ /dev/null @@ -1,163 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 2; -constexpr int NumReduceDim = 1; - -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - -int main(int argc, char* argv[]) -{ - ck::index_t M = 1024; - ck::index_t N = 1024; - ck::index_t Stride = 1024; - - auto xy_size = (M - 1) * Stride + N; - - SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); - SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); - SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); - SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); - - using DeviceOp = ck::tensor_operation::device::DeviceNormalization; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - bool found = false; - int best_op_id = -1; - float best_ave_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - - auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths - {Stride, 1}, // xStrides - {0, 1}, // gammaStrides - {0, 1}, // betaStrides - {Stride, 1}, // yStrides - {1}, // reduceDims - 1e-4, - x_device_buf.GetDeviceBuffer(), - gamma_device_buf.GetDeviceBuffer(), - beta_device_buf.GetDeviceBuffer(), - y_device_buf.GetDeviceBuffer(), - nullptr, - nullptr, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + - sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; - - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " - << op_name << std::endl; - - if(ave_time < best_ave_time) - { - found = true; - best_op_id = i; - best_op_name = op_name; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - } - else - { - std::cout << op_name << " does not support this problem" << std::endl; - } - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " - << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - - auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths - {Stride, 1}, // xStrides - 
{1}, // gammaStrides - {1}, // betaStrides - {Stride, 1}, // yStrides - {1}, // reduceDims - 1e-4, - x_device_buf.GetDeviceBuffer(), - gamma_device_buf.GetDeviceBuffer(), - beta_device_buf.GetDeviceBuffer(), - y_device_buf.GetDeviceBuffer(), - nullptr, - nullptr, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } - - return 0; -} diff --git a/client_example/05_layernorm/layernorm2d_bwd_data.cpp b/client_example/05_layernorm/layernorm2d_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ec02cb2c4ee8c2dc345ca9145e78d220ca716bd2 --- /dev/null +++ b/client_example/05_layernorm/layernorm2d_bwd_data.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DXDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * M * N); + SimpleDeviceMem x_dev(sizeof(XDataType) * M * N); + SimpleDeviceMem gamma_dev(sizeof(GammaDataType) * N); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem dx_dev(sizeof(DXDataType) * M * N); + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem 
workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = sizeof(DYDataType) * M * N + sizeof(XDataType) * M * N + + sizeof(GammaDataType) * N + sizeof(MeanInvStdDataType) * M * 2 + + sizeof(DXDataType) * M * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1d1ebefd5b18355a6e26e7ef29cb31d656f6c9ef --- /dev/null +++ b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
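The new backward examples above query a per-argument workspace before running: make the argument, ask the op how much scratch it needs, allocate that much device memory, attach it, then run. The sketch below condenses just that sequence; FakeOp and FakeArgument are hypothetical stand-ins (the real examples allocate with hipMalloc via SimpleDeviceMem), so only the call order is meant to carry over.

    #include <cstddef>
    #include <memory>

    // Hypothetical stand-ins for a CK device op and its argument; only the
    // workspace-related calls used in the examples above are modeled here.
    struct FakeArgument {};
    struct FakeOp
    {
        std::size_t GetWorkSpaceSize(const FakeArgument*) const { return 4096; }
        void SetWorkSpacePointer(FakeArgument*, void* p) { workspace = p; }
        void Run(const FakeArgument*) { /* kernel launch would go here */ }
        void* workspace = nullptr;
    };

    int main()
    {
        FakeOp op;
        auto arg = std::make_unique<FakeArgument>();

        // 1) ask the op how much scratch this specific argument needs,
        // 2) allocate it (device memory in the real examples),
        // 3) attach it before the run.
        std::size_t workspace_sz = op.GetWorkSpaceSize(arg.get());
        std::unique_ptr<char[]> workspace(new char[workspace_sz]);
        op.SetWorkSpacePointer(arg.get(), workspace.get());

        op.Run(arg.get());
        return 0;
    }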
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * M * N); + SimpleDeviceMem x_dev(sizeof(XDataType) * M * N); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * N); + SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * N); + + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::size_t num_bytes = sizeof(DYDataType) * M * N + sizeof(XDataType) * M * N + + sizeof(MeanInvStdDataType) * M * 2 + sizeof(DGammaDataType) * N + + sizeof(DBetaDataType) * N; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // inLengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N}, // outLengths + {1}, // dgammaStrides + {1}, // dbetaStrides + {0}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << 
"Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // inLengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N}, // outLengths + {1}, // dgammaStrides + {1}, // dbetaStrides + {0}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/05_layernorm/layernorm2d_fwd.cpp b/client_example/05_layernorm/layernorm2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..22599f43cabc3dc31e8f6a8a05adf275faac1687 --- /dev/null +++ b/client_example/05_layernorm/layernorm2d_fwd.cpp @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = ck::half_t; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t Stride = 1024; + + auto xy_size = (M - 1) * Stride + N; + + SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); + SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); + SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); + SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); +#ifdef SAVE_MEAN_INV_STD + SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * M); + SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M); +#endif + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { 
+ auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {Stride, 1}, // xStrides + {0, 1}, // gammaStrides + {0, 1}, // betaStrides + {Stride, 1}, // yStrides + {1}, // save_mean Strides + {1}, // save_inv_std Strides + {1}, // reduceDims + 1e-4, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; + +#ifdef SAVE_MEAN_INV_STD + num_byte += sizeof(SaveMeanInvStdDataType) * M * 2; +#endif + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {Stride, 1}, // xStrides + {0, 1}, // gammaStrides + {0, 1}, // betaStrides + {Stride, 1}, // yStrides + {1}, // save_mean Strides + {1}, // save_inv_std Strides + {1}, // reduceDims + 1e-4, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/05_layernorm/layernorm4d_fwd.cpp b/client_example/05_layernorm/layernorm4d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c80fd31b6ee77099338799cc2b4713f1ea274fae --- /dev/null +++ b/client_example/05_layernorm/layernorm4d_fwd.cpp @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
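The GB/s figure printed by these examples is bytes moved divided by elapsed time. Since ave_time is in milliseconds, dividing the byte count by 1.E6 and then by the millisecond value is the same as dividing by 1e9 bytes per GB and by seconds. A small numeric check with illustrative figures:

    #include <cstdio>

    int main()
    {
        // Illustrative figures: roughly 8 MiB of traffic in 0.05 ms.
        double num_byte = 8.0 * 1024 * 1024; // bytes moved
        double ave_time = 0.05;              // milliseconds

        // bytes / 1e6 / ms == (bytes / 1e9) / (ms / 1e3) == GB per second
        double gb_per_sec = num_byte / 1.E6 / ave_time;
        std::printf("%.1f GB/s\n", gb_per_sec); // ~167.8 GB/s
        return 0;
    }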
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = ck::half_t; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 4; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 256; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t C = 8; + + std::vector strideXY = {H * W * C, W * C, C, 1}; + std::vector strideGammaBeta = {0, W * C, C, 1}; + std::vector strideSaveMeanInvStd = {1}; + + SimpleDeviceMem x_device_buf(sizeof(XDataType) * N * H * W * C); + SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * H * W * C); + SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * H * W * C); + SimpleDeviceMem y_device_buf(sizeof(YDataType) * N * H * W * C); +#ifdef SAVE_MEAN_INV_STD + SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N); + SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N); +#endif + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, C}, // lengths + strideXY, // xStrides + strideGammaBeta, // gammaStrides + strideGammaBeta, // betaStrides + strideXY, // yStrides + strideSaveMeanInvStd, // save_mean Strides + strideSaveMeanInvStd, // save_inv_std Strides + {1, 2, 3}, // reduceDims + 1e-4, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = + sizeof(XDataType) * N * H * W * C + sizeof(GammaDataType) * H * W * C + + sizeof(BetaDataType) * H * W 
* C + sizeof(YDataType) * N * H * W * C; + +#ifdef SAVE_MEAN_INV_STD + num_byte += sizeof(SaveMeanInvStdDataType) * N * 2; +#endif + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, C}, // lengths + strideXY, // xStrides + strideGammaBeta, // gammaStrides + strideGammaBeta, // betaStrides + strideXY, // yStrides + strideSaveMeanInvStd, // save_mean Strides + strideSaveMeanInvStd, // save_inv_std Strides + {1, 2, 3}, // reduceDims + 1e-4, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/06_softmax/CMakeLists.txt b/client_example/06_softmax/CMakeLists.txt index b38a0fd9e27570e62df7f701572a24fb4ee842f9..24d30f475ec8c44b41375d9d972f04e6aeec4364 100644 --- a/client_example/06_softmax/CMakeLists.txt +++ b/client_example/06_softmax/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_softmax4d softmax4d.cpp) -target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_operations) +target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations) diff --git a/client_example/06_softmax/softmax4d.cpp b/client_example/06_softmax/softmax4d.cpp index 2ccad27a88757b576a050d0558acc4108857cbd1..eaddbf98ee376cac0fbac559dabefc758f9c6af5 100644 --- a/client_example/06_softmax/softmax4d.cpp +++ b/client_example/06_softmax/softmax4d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
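In the 4D forward example above, strideGammaBeta sets the N stride to 0 so a single H*W*C gamma/beta tensor is reused for every image in the batch; a zero stride simply makes the indexing ignore that dimension. A tiny host-side illustration of the same trick (shapes and values are illustrative):

    #include <cstdio>
    #include <vector>

    int main()
    {
        const int N = 2, C = 3;
        std::vector<float> gamma = {1.f, 2.f, 3.f}; // one value per channel

        // Strides {0, 1}: the N stride is zero, so every n maps onto the same row.
        const int strideN = 0, strideC = 1;
        for(int n = 0; n < N; ++n)
            for(int c = 0; c < C; ++c)
                std::printf("gamma[n=%d][c=%d] = %f\n", n, c,
                            gamma[n * strideN + c * strideC]);
        return 0;
    }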
#include #include @@ -140,6 +140,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index fce7e91c1ef5b23fcfecfe99cc49263a1d779d76..c953e21d0266f61b1b9fc99de9252a4f00bd57cd 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -1,5 +1,25 @@ -add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) -target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) + target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) -add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) -target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_operations) + add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) + target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) + + if((DTYPES MATCHES "fp8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) + add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) + add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) + add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) + + add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() +endif() diff --git a/client_example/07_grouped_convnd_fwd/common.hpp b/client_example/07_grouped_convnd_fwd/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..729af0b88b8d48a3361098aac1a155ad4e013308 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/common.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
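The new common.hpp that begins here centralizes the per-example bookkeeping for the grouped convolution clients. Its GetFlops helper (defined just below) counts 2 * G * N * K * C times the output spatial extent times the filter spatial extent, i.e. one multiply-accumulate (two FLOPs) per output element, per input channel, per filter tap. A minimal numeric check of that formula; the shape values are illustrative, not tied to a specific example in this patch:

    #include <cstdio>

    int main()
    {
        // 2D grouped conv: G groups, N images, K output / C input channels per group,
        // Ho x Wo output pixels, Y x X filter taps.
        long long G = 1, N = 64, K = 128, C = 64, Ho = 28, Wo = 28, Y = 3, X = 3;

        // One multiply-accumulate = 2 FLOPs, done per output element,
        // per input channel, per filter tap.
        long long flops = 2 * G * N * K * C * (Ho * Wo) * (Y * X);
        std::printf("%.2f GFLOP\n", flops / 1e9); // ~7.40 GFLOP for this shape
        return 0;
    }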
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths) +{ + // 2 * G * N * K * C * * + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return static_cast(2) * G * N * K * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd(std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + 
std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t flop = GetFlops(out_lengths, wei_lengths); + std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best 
instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 70be0101c6d92512ab28a72797d2b8c46fb55281..d3a3111e945da61d30164aec4ee23b1c1c5e4689 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wo = 28; -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - std::array in_lengths{G, N, Wi, C}; - std::array in_strides{0, 0, 0, 1}; - - std::array wei_lengths{G, K, X, C}; - std::array wei_strides{0, 0, 0, 1}; - - std::array out_lengths{G, N, Wo, K}; - std::array out_strides{0, 0, 0, 1}; - - std::partial_sum(rbegin(in_lengths), - std::prev(rend(in_lengths)), - std::next(rbegin(in_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(wei_lengths), - std::prev(rend(wei_lengths)), - std::next(rbegin(wei_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(out_lengths), - std::prev(rend(out_lengths)), - std::next(rbegin(out_strides)), - std::multiplies<>{}); - - // transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNCW - std::rotate(rbegin(in_lengths), - std::next(rbegin(in_lengths)), - std::next(rbegin(in_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(in_strides), - std::next(rbegin(in_strides)), - std::next(rbegin(in_strides), NumDimSpatial + 1)); - std::rotate(rbegin(wei_lengths), - std::next(rbegin(wei_lengths)), - std::next(rbegin(wei_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(wei_strides), - std::next(rbegin(wei_strides)), - std::next(rbegin(wei_strides), NumDimSpatial + 1)); - std::rotate(rbegin(out_lengths), - std::next(rbegin(out_lengths)), - std::next(rbegin(out_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(out_strides), - std::next(rbegin(out_strides)), - std::next(rbegin(out_strides), NumDimSpatial + 1)); - - std::array 
filter_strides{1}; - std::array filter_dilations{1}; - std::array input_left_pads{1}; - std::array input_right_pads{1}; - - SimpleDeviceMem in(sizeof(InDataType) * G * N * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Wo * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Wi * C + - sizeof(WeiDataType) * G * K * X * C + - sizeof(OutDataType) * G * N * Wo * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Wi, G, C}, {G, K, X, C}, {N, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index 57a210fa1f5f0d3f0eca014e08ccd119d29e4fd6..fb8a410ab39b71b56a4dfb4df8b62aab16e39d35 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W static constexpr ck::index_t Ho = 28; // output H static constexpr ck::index_t Wo = 28; // output W -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space - // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW - // Hence, we need to adjust the order of stride - std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; - std::array wei_lengths{G, K, C, Y, X}; - std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; - std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; - - std::array filter_strides{1, 1}; - std::array filter_dilations{1, 1}; - std::array input_left_pads{1, 1}; - std::array input_right_pads{1, 1}; - - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - 
- if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + - sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * N * Ho * Wo * G * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Hi, Wi, G, C}, {G, K, Y, X, C}, {N, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..983e0d083c59f6961030ba36b4a3fb09bf57f5f6 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b195d87bbcb723b0e3214b365a177323ff39446c --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using AComputeType = ck::bf8_t; +using BComputeType = ck::f8_t; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2506e29e0e5eec5b7cf630bc03ab559ba6c2a266 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8508dc9c55dc205a6548a14147e689b70e19e8d3 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using AComputeType = ck::f8_t; +using BComputeType = ck::bf8_t; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/08_fused_attention/CMakeLists.txt b/client_example/08_fused_attention/CMakeLists.txt index 862b9ed5b70ee8a35c4ba44e2ccf1441e1532bd3..4bcde367dc661cb5211802c9f06d64047afaa218 100644 --- a/client_example/08_fused_attention/CMakeLists.txt +++ b/client_example/08_fused_attention/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_fused_attention fused_attention.cpp) -target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_fused_attention fused_attention.cpp) + target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_fused_attention_bias fused_attention_bias.cpp) -target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations) + add_executable(client_fused_attention_bias fused_attention_bias.cpp) + target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/08_fused_attention/fused_attention.cpp b/client_example/08_fused_attention/fused_attention.cpp index df6bc11a70d32df221612466a8af0fbcd9cafb1c..339d92e756c5f7c5ddd541650f7b2fe35f1334f5 100644 --- a/client_example/08_fused_attention/fused_attention.cpp +++ b/client_example/08_fused_attention/fused_attention.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
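The new fp8/bf8 forward-convolution clients above only declare the problem shape and hand it to the shared run_grouped_conv_fwd helper; with 3x3x3 filters, unit stride and dilation, and padding of 1, the chosen shapes keep Do = Di, Ho = Hi and Wo = Wi. A standalone sketch (not part of the diff) of the output-size formula those shapes satisfy:

```cpp
// Standalone sketch: the usual convolution output-size formula, shown only to
// check that the shapes hard-coded in the client examples are self-consistent.
// Plain C++, not CK code.
#include <cstdint>
#include <iostream>

int64_t conv_out_size(int64_t in, int64_t filter, int64_t pad_l, int64_t pad_r,
                      int64_t stride, int64_t dilation)
{
    // effective filter extent after dilation
    const int64_t eff = dilation * (filter - 1) + 1;
    return (in + pad_l + pad_r - eff) / stride + 1;
}

int main()
{
    // Wi = 3,  X = 3, pad = 1, stride = 1, dilation = 1  ->  Wo = 3
    std::cout << conv_out_size(3, 3, 1, 1, 1, 1) << '\n';
    // Di = 28, Z = 3, pad = 1, stride = 1, dilation = 1  ->  Do = 28
    std::cout << conv_out_size(28, 3, 1, 1, 1, 1) << '\n';
}
```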
#include #include diff --git a/client_example/08_fused_attention/fused_attention_bias.cpp b/client_example/08_fused_attention/fused_attention_bias.cpp index 6c9f3bc8f6f5a3c06f339f1246b5b3985e11d2d8..a1200a9db44e0a954650d75caac159f5c193d275 100644 --- a/client_example/08_fused_attention/fused_attention_bias.cpp +++ b/client_example/08_fused_attention/fused_attention_bias.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index ac11aad45de84b8a2e042d8a1a4bb059b889f3f2..d2d3a427e8085f289b98f84463f089aebe567b40 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,22 +1,22 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) -add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations) +if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)) + add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_operations) + 
add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_gemm_quantization gemm_quantization.cpp) -target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_gemm_quantization gemm_quantization.cpp) + target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) endif() diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp index cd504e942e943b8f174442554eca379cd908ba6e..08919401cdd2a2ab88a3943504937e347721a333 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -80,7 +80,7 @@ int main(int argc, char* argv[]) SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< NumDimSpatial, InLayout, WeiLayout, diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp index f4aa3666b1c15eccd37bdef549f8402bcd1b252b..1d502ba4a2e1327f2cefcceb168e35135183b3ad 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
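The quantization clients move from DeviceGroupedConvFwdMultipleD to DeviceGroupedConvFwdMultipleABD; the per-channel variants still carry a requant_scale buffer of G * K elements as a D-tensor, while the per-layer variants use a single scale. A hedged plain-C++ sketch of the requantization arithmetic only (the real element-wise op is composed inside CK on the GPU; the function names here are illustrative):

```cpp
// Plain-C++ illustration of per-channel vs. per-layer requantization as used
// conceptually by the int8 examples above. Illustrative names, not CK API.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Per-channel: one scale per output channel k (a buffer of G * K floats).
int8_t requant_per_channel(int32_t acc, float scale_k)
{
    const float v = acc * scale_k;
    return static_cast<int8_t>(std::lround(std::clamp(v, -128.f, 127.f)));
}

// Per-layer: a single scale shared by every output element.
int8_t requant_per_layer(int32_t acc, float scale)
{
    const float v = acc * scale;
    return static_cast<int8_t>(std::lround(std::clamp(v, -128.f, 127.f)));
}

int main()
{
    std::cout << int(requant_per_channel(1000, 0.05f)) << ' '  // 50
              << int(requant_per_layer(1000, 0.2f)) << '\n';   // 127 (clamped)
}
```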
#include #include @@ -78,18 +78,18 @@ int main(int argc, char* argv[]) SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple, - OutDataType, - PassThrough, - PassThrough, - OutElementOp>; + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp index ebdbbf52c0ca257dd8f672a48b32f80fa5bb8816..5b9c9d370883851bf1585287133d7d57aa7241e9 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -83,7 +83,7 @@ int main(int argc, char* argv[]) SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< NumDimSpatial, InLayout, WeiLayout, diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp index 9d60baee06fddb27451f1f51642bad738739c4bb..7c40aa4e60f7ed552c145c244c881a1d80c5b7c7 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -79,18 +79,18 @@ int main(int argc, char* argv[]) SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple, - OutDataType, - PassThrough, - PassThrough, - OutElementOp>; + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp index dd81d9ee6b6dc0ee453cac61d7e795c79ba3c9fa..3777cd5e1b98ea785ab750c1144996f2f9b4f772 100644 --- a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -76,19 +76,19 @@ int main(int argc, char* argv[]) SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - using DeviceOp = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple, - OutDataType, - PassThrough, - PassThrough, - OutElementOp>; + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + NumDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp index 9c088a21d38e7f62c8eee35bf4c557a0bb2209bd..1fbb1ddea4c72f30d7fecbe08f7b4ee8cc084558 100644 --- a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -72,18 +72,18 @@ int main(int argc, char* argv[]) SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - OutElementOp>; + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/09_quantization/gemm_quantization.cpp b/client_example/09_quantization/gemm_quantization.cpp index b14e68fa082f8f9d05ab5f00471a2fa82d1d113b..d2fadd8d91419ce729cbf0320e9700eec4644e3a 100644 --- a/client_example/09_quantization/gemm_quantization.cpp +++ b/client_example/09_quantization/gemm_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include diff --git a/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt b/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt deleted file mode 100644 index e564f3180d8c9295356d90266f2159de47aac060..0000000000000000000000000000000000000000 --- a/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) -target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_operations) diff --git a/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d10c39ed801df6ffd3f833f53c5a6253e7c34bcc --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,10 @@ +add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) +target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_conv3d_bwd_data grouped_conv3d_bwd_data.cpp) +target_link_libraries(client_grouped_conv3d_bwd_data PRIVATE composable_kernel::device_conv_operations) + +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) + add_executable(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp) + target_link_libraries(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 PRIVATE composable_kernel::device_conv_operations) +endif() \ No newline at end of file diff --git a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp similarity index 99% rename from client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp rename to client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp index 1b2e8abc201c2aed2cd2eebccb68405a25033a43..ae5f1b6f6ef27a54e0de8fce8004bdb8e70d7658 100644 --- a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93709a79019cad1a783192423db53e0105e6c011 --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 16; +static constexpr ck::index_t K = 16; +static constexpr ck::index_t C = 16; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 14; +static constexpr ck::index_t Hi = 14; +static constexpr ck::index_t Wi = 14; +static constexpr ck::index_t Do = 14; +static constexpr ck::index_t Ho = 14; +static constexpr ck::index_t Wo = 14; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string 
op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a62a1d911bc9cd535874d56671971b833c150341 --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
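The new grouped_conv3d_bwd_data client passes lengths in G, N, C, D, H, W order while the strides describe tensors packed as NDHWGC / GKZYXC / NDHWGK in memory, the same convention the removed comment in the 2D forward example spelled out. A standalone sketch of how the hard-coded in_strides array is derived (an illustrative helper, not a CK API):

```cpp
// Standalone sketch: derive the GNCDHW-ordered stride array for a tensor laid
// out contiguously as NDHWGC, matching the in_strides value hard-coded in
// grouped_conv3d_bwd_data.cpp. Illustrative only, not CK API.
#include <array>
#include <cstddef>
#include <iostream>

std::array<std::size_t, 6>
ndhwgc_strides_in_gncdhw_order(std::size_t G, std::size_t N, std::size_t C,
                               std::size_t Di, std::size_t Hi, std::size_t Wi)
{
    // memory order (fastest to slowest): C, G, Wi, Hi, Di, N
    return {C,                    // stride of G
            Di * Hi * Wi * G * C, // stride of N
            1,                    // stride of C
            Hi * Wi * G * C,      // stride of Di
            Wi * G * C,           // stride of Hi
            G * C};               // stride of Wi
}

int main()
{
    // G=2, N=16, C=16, Di=Hi=Wi=14, as in the example above
    for(auto s : ndhwgc_strides_in_gncdhw_order(2, 16, 16, 14, 14, 14))
        std::cout << s << ' ';
    std::cout << '\n'; // 16 87808 1 6272 448 32
}
```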
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 16; +static constexpr ck::index_t K = 16; +static constexpr ck::index_t C = 16; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 14; +static constexpr ck::index_t Hi = 14; +static constexpr ck::index_t Wi = 14; +static constexpr ck::index_t Do = 14; +static constexpr ck::index_t Ho = 14; +static constexpr ck::index_t Wo = 14; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough, + ck::bf8_t, + ck::f8_t>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = 
op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt index 82162b606a6ee055cb985b960366a93fa4dd21e3..60a6dc10210ab6e5a4285a3d377c52005955c936 100644 --- a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt +++ b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt @@ -3,7 +3,12 @@ add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_f add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp) add_executable(client_grouped_conv3d_bwd_weight_fp32 grouped_conv3d_bwd_weight_fp32.cpp) -target_link_libraries(client_grouped_conv1d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_grouped_conv1d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) +target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) +target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) +target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_conv_operations) + +if((DTYPES MATCHES "fp8" 
AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) + add_executable(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp) + target_link_libraries(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) +endif() \ No newline at end of file diff --git a/client_example/11_grouped_conv_bwd_weight/common.hpp b/client_example/11_grouped_conv_bwd_weight/common.hpp index 4292cded2071526381db36c31a5e660f7087cbce..541a0a19a07fb778eeff2f3fdf25b25cfbc1ab4f 100644 --- a/client_example/11_grouped_conv_bwd_weight/common.hpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -85,7 +85,9 @@ template + typename OutLayout, + typename AComputeType = InDataType, + typename BComputeType = AComputeType> bool run_grouped_conv_bwd_weight( const std::array& input_lengths, const std::array& input_strides, @@ -113,7 +115,9 @@ bool run_grouped_conv_bwd_weight( OutDataType, PassThrough, PassThrough, - PassThrough>; + PassThrough, + AComputeType, + BComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); @@ -156,6 +160,10 @@ bool run_grouped_conv_bwd_weight( auto invoker_ptr = op_ptr->MakeInvokerPointer(); std::string op_name = op_ptr->GetTypeString(); + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp index e6d427faf4d75033a931ec2bfeeb72f4cec90e31..a51aab483e55daab7f2b2d0d2974c743941c15e7 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp index 4201ea61b4ee3c62f0bb8034b6e0510242000942..705ad21ae8d257c11bdf465d4215d9c04ee4e84b 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
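common.hpp for the backward-weight clients now queries the workspace size from each instance and binds a device buffer before the support check and the timed run. For reference, a minimal RAII sketch of that pattern; the HIP calls follow the existing SimpleDeviceMem helper, and the CK calls in the trailing comment are quoted from the diff:

```cpp
// Standalone sketch of the workspace pattern added to common.hpp: an RAII
// device buffer sized by the instance and attached to the argument before the
// support check and the run. Error handling is omitted, as in the originals.
#include <cstddef>
#include <hip/hip_runtime.h>

struct WorkspaceBuffer
{
    explicit WorkspaceBuffer(std::size_t size_in_bytes) : p_mem_{nullptr}
    {
        (void)hipMalloc(&p_mem_, size_in_bytes);
    }
    ~WorkspaceBuffer() { (void)hipFree(p_mem_); }

    void* Get() const { return p_mem_; }

    void* p_mem_;
};

// Usage inside the profiling loop (op_ptr / argument_ptr as in common.hpp):
//   const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
//   WorkspaceBuffer workspace_dev(workspace_sz);
//   op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.Get());
//   if(op_ptr->IsSupportedArgument(argument_ptr.get())) { /* time the run */ }
```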
#include "common.hpp" diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp index 3ae46bcd5566cd3958e1b4e666119ec55e3528e9..5ed3896e7a7719fe6b9390fc1d9da81465665ff4 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..868e0e290310701e93b8d56c839440d74b4b2645 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using AComputeType = ck::bf8_t; +using BComputeType = ck::f8_t; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 8; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 128; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; +static constexpr std::array input_lengths{G, N, C, Di, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Z, Y, X}; +static constexpr std::array output_lengths{G, N, K, Do, Ho, Wo}; +static constexpr std::array input_strides{ + N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C}; +static constexpr std::array output_strides{ + N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K}; +static constexpr std::array conv_filter_strides{1, 1, 1}; +static constexpr std::array conv_filter_dilations{1, 1, 1}; +static constexpr std::array input_left_pads{1, 1, 1}; +static constexpr std::array input_right_pads{1, 1, 1}; + +int main() +{ + return run_grouped_conv_bwd_weight(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp index 2eb869f3923ad0d743d5d0061842683e2fd95f87..d5f1fc331bc9e02f7c87bb39ddb1b6c0cebe8259 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/12_elementwise_normalization/CMakeLists.txt b/client_example/12_elementwise_normalization/CMakeLists.txt index 1ba0e1279a402ae00a6e3aead9de6c12c61e14c9..738647de595ad0151a12751d0d9c6b72a765ae2c 100644 --- a/client_example/12_elementwise_normalization/CMakeLists.txt +++ b/client_example/12_elementwise_normalization/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_elementwise_layernorm2d elementwise_layernorm2d.cpp) -target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_operations) +target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index bc4a6fe0bfa9e118bbd6ba32ecc7dd68f3b8b2c3..69d7c8936ce1317051cd3d2e9d5bc2543df58e50 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
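Both updated common.hpp helpers give the new compute-type template parameters defaults (AComputeType = InDataType, BComputeType = AComputeType), so existing fp16/fp32 callers compile unchanged while the bf8/fp8 clients above override them explicitly. A minimal sketch of that defaulting pattern, with run_conv as a hypothetical stand-in rather than the CK helper:

```cpp
// Minimal sketch of defaulted compute-type template parameters: callers that
// say nothing get AComputeType == InDataType and BComputeType == AComputeType;
// the bf8/fp8 clients override both. run_conv is illustrative, not CK code.
#include <iostream>
#include <typeinfo>

template <typename InDataType,
          typename AComputeType = InDataType,
          typename BComputeType = AComputeType>
void run_conv()
{
    std::cout << typeid(InDataType).name() << ' '
              << typeid(AComputeType).name() << ' '
              << typeid(BComputeType).name() << '\n';
}

int main()
{
    run_conv<float>();              // both compute types default to float
    run_conv<float, char, short>(); // both compute types overridden explicitly
}
```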
#include #include @@ -142,6 +142,7 @@ int main() << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/13_batchnorm/CMakeLists.txt b/client_example/13_batchnorm/CMakeLists.txt index fc4f9d395c4a8b45e34db2ad0c5622484d379dba..420ea25752082117c6e99d8fdaebf3ed11883a3d 100644 --- a/client_example/13_batchnorm/CMakeLists.txt +++ b/client_example/13_batchnorm/CMakeLists.txt @@ -1,6 +1,6 @@ add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp) add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp) add_executable(client_batchnorm_infer_nhwc batchnorm_infer_nhwc.cpp) -target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations) -target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_operations) -target_link_libraries(client_batchnorm_infer_nhwc PRIVATE composable_kernel::device_operations) +target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_other_operations) +target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_other_operations) +target_link_libraries(client_batchnorm_infer_nhwc PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp index 1ed36e0f50f05483a938a9a7b1c433df003b3b39..4f6985a5142d037b4c7d04a2972ede0ff82cf109 100644 --- a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp index f9af011c8480a5caa510a66467edc33e886fe0e2..9fa82523be34d5c5abb46fd3027594eaed1223f8 100644 --- a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp index 5e6627ce14d113224c7b0acb4fea69cb36c1f369..6393cf3e651a8d9d91f53e8e36385dc2db5d7c0e 100644 --- a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
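The if(found) guard added to elementwise_layernorm2d.cpp (and to gemm_add_multiply.cpp further down in this diff) keeps the program from indexing op_ptrs with a negative best_op_id when no instance supports the problem. A standalone sketch of the select-then-guard pattern, with illustrative stand-ins for the instance list and measured times:

```cpp
// Standalone sketch of the "pick the fastest supported instance, then guard"
// pattern used by the examples; times_ms is an illustrative stand-in for the
// op_ptrs loop (negative means the instance does not support the problem).
#include <iostream>
#include <limits>
#include <vector>

int main()
{
    const std::vector<float> times_ms = {0.42f, -1.f, 0.37f};

    int best_id   = -1;
    float best_ms = std::numeric_limits<float>::max();
    bool found    = false;

    for(int i = 0; i < static_cast<int>(times_ms.size()); ++i)
    {
        if(times_ms[i] < 0.f)
            continue; // unsupported instance, skip it
        found = true;
        if(times_ms[i] < best_ms)
        {
            best_ms = times_ms[i];
            best_id = i;
        }
    }

    if(found) // same guard the diff adds before rerunning the best instance
        std::cout << "best instance " << best_id << ": " << best_ms << " ms\n";
    else
        std::cout << "no suitable instance\n";
}
```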
#include #include diff --git a/client_example/14_instance_id/CMakeLists.txt b/client_example/14_instance_id/CMakeLists.txt index 87b2a9a0cbc82a82c9567266178ee431fb9149e7..6ba0e59f5aebeb33f63333d026eeaad24e2258fd 100644 --- a/client_example/14_instance_id/CMakeLists.txt +++ b/client_example/14_instance_id/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp) -target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_operations) +target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp index d45782d8e0ff37027a204c6820447286581f1138..2a565738a78f4b62a17b0e1af62588c50a498f84 100644 --- a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp +++ b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/15_convnd_bwd_data/CMakeLists.txt b/client_example/15_convnd_bwd_data/CMakeLists.txt index 8a60a71674f84f6005deadf3744b7df346c0ac3c..8fc62bc2bb0e8ecaf5dab785b1b6eb88698c29b8 100644 --- a/client_example/15_convnd_bwd_data/CMakeLists.txt +++ b/client_example/15_convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) -add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) + add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) -target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_operations) + target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations) + target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp index 5210567241ebd9ce84b8b0b576b295594bf8578b..29dbc97f408af7e11a76ba3d3844fd6dd6c314f1 100644 --- a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp index 441bdfe7bec3bd12fe948c2efc892ad5cc4d0184..b53e892fdc94f10e671c8893adf88603cbaadfdf 100644 --- a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" diff --git a/client_example/15_gemm_add_multiply/CMakeLists.txt b/client_example/15_gemm_add_multiply/CMakeLists.txt index fd2dcf9614016b93b43a52c2940d1701ee623584..a683f78571a8b9cb85d201818e27538d521ef567 100644 --- a/client_example/15_gemm_add_multiply/CMakeLists.txt +++ b/client_example/15_gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,4 @@ - -add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) -target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) + target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp index c74d7c6bd8cf9cec54168792a8718768827d7b33..a8c2ae1214ae6c01f95297c2eb481bb4d18d8541 100644 --- a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp +++ b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -204,6 +204,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/15_reduce/CMakeLists.txt b/client_example/15_reduce/CMakeLists.txt index d52675ba834e666d0a480d5c778baeacb223a24a..a944af5e54d27b014956951a78daa2bacac9573b 100644 --- a/client_example/15_reduce/CMakeLists.txt +++ b/client_example/15_reduce/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_reduce_nhwc_c reduce_nhwc_c.cpp) -target_link_libraries(client_reduce_nhwc_c PRIVATE composable_kernel::device_operations) +target_link_libraries(client_reduce_nhwc_c PRIVATE composable_kernel::device_reduction_operations) diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index b45b72f0de0199daa88e9e42be320e2669398dfe..e2b1fbcb545aa63b689b021ff77849cd1443fcb2 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index e2580a370ca480e0b9ec101e5715ac9a631a72f6..8c1372e74152b6fd3d32a72a063baa7761c08fdf 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -1,5 +1,15 @@ -add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) -add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) +if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) + target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_conv_operations) -target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) +endif() + +if((DTYPES MATCHES "fp8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) + add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) + target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations) +endif() + +if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) + target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index 449c9466e829baa8658f17fe8e9ae86c5a297238..ee408c7443fbedde726ca7263a0c63c7d509d1e7 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -11,7 +11,7 @@ #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -94,7 +94,9 @@ template + ck::index_t NumNonSpatialDim = 3, + typename AComputeType = InDataType, + typename BComputeType = AComputeType> bool run_grouped_conv_fwd(std::array in_lengths, std::array wei_lengths, std::array out_lengths) @@ -173,18 +175,20 @@ bool run_grouped_conv_fwd(std::array(out_lengths, wei_lengths); std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough, + AComputeType, + BComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp index d4455df628afcb6ec721b5f884b79e4572a5037c..10033822ddcbce1eeccfec79efa130d40e5fb932 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..22ba25efb986c95acc9110877966f0602d3ce0d2 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp index 7e8c98b6037418001571707a82aa14987059fb4f..a739f9d05b28ad2cb0d9e52aa06c66c04b930ba5 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt index 659e6769d8e69f227ea1ef5920c3951bba706f0d..39bef71814a0909013b284f58fb0255c5d14eddf 100644 --- a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt +++ b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) -target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) + target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp index 7ba3224fc3244ab3bd472b98864efa42ae132fde..6a745e1ab008f638f5ab6a1de728376f742d7dc1 100644 --- a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp +++ b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/18_groupnorm/CMakeLists.txt b/client_example/18_groupnorm/CMakeLists.txt index 17c88cb61bc5daaa3ac0c6a92459c147691be2fe..e04c26d8e7bc220825912e0a38297b75b6a93370 100644 --- a/client_example/18_groupnorm/CMakeLists.txt +++ b/client_example/18_groupnorm/CMakeLists.txt @@ -1,2 +1,8 @@ -add_executable(client_groupnorm_swish groupnorm_swish.cpp) -target_link_libraries(client_groupnorm_swish PRIVATE composable_kernel::device_operations) +add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp) +target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations) + +add_executable(client_groupnorm_bwd_gamma_beta groupnorm_bwd_gamma_beta.cpp) +target_link_libraries(client_groupnorm_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations) + +add_executable(client_groupnorm_swish_fwd groupnorm_swish_fwd.cpp) +target_link_libraries(client_groupnorm_swish_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/18_groupnorm/groupnorm_bwd_data.cpp b/client_example/18_groupnorm/groupnorm_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bcfa5f7dc604989af550b91747713eeb3053a8b4 --- /dev/null +++ b/client_example/18_groupnorm/groupnorm_bwd_data.cpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
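The groupnorm backward-data example that follows reports bandwidth as num_byte / 1.E6 / ave_time, where num_byte counts the dy, x and dx tensors plus the gamma and mean/inv-std side inputs. With the sizes used in the example (N=32, H=16, W=16, G=64, C=128, all float) the traffic is roughly 0.8 GB per invocation; a standalone recheck of that arithmetic (the 1 ms kernel time is hypothetical, only to show the unit conversion):

#include <cstddef>
#include <iostream>

int main()
{
    // Sizes used by the groupnorm backward-data client example below.
    const std::size_t N = 32, H = 16, W = 16, G = 64, C = 128;
    const std::size_t length = N * H * W * G * C; // 67,108,864 elements

    // dy, x and dx are full tensors; gamma is (G, C); mean and inv-std are (N, G).
    // All data types in the example are float (4 bytes).
    const std::size_t num_byte = sizeof(float) * length        // dy
                                 + sizeof(float) * length      // x
                                 + sizeof(float) * G * C       // gamma
                                 + sizeof(float) * N * G * 2   // mean + inv_std
                                 + sizeof(float) * length;     // dx

    const float ave_time_ms = 1.0f; // hypothetical kernel time, illustration only
    const float gb_per_sec  = num_byte / 1.0e6f / ave_time_ms;

    std::cout << "bytes moved: " << num_byte << " (~0.8 GB)\n";
    std::cout << "at " << ave_time_ms << " ms that is " << gb_per_sec << " GB/s\n";
    return 0;
}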
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DXDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + std::size_t length = N * H * W * G * C; + + std::vector strideDy = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector strideX = strideDy; + std::vector strideDx = strideDy; + + std::vector strideGamma = {0, 0, 0, C, 1}; + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * length); + SimpleDeviceMem x_dev(sizeof(XDataType) * length); + SimpleDeviceMem gamma_dev(sizeof(GammaDataType) * G * C); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem dx_dev(sizeof(DXDataType) * length); + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + {1, 2, 4}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = sizeof(DYDataType) * length + sizeof(XDataType) * length + + sizeof(GammaDataType) * G * C + + sizeof(MeanInvStdDataType) * N * G * 2 + + sizeof(DXDataType) * length; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time 
= ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + {1, 2, 4}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp new file mode 100644 index 0000000000000000000000000000000000000000..06ab194a8e5be8f6bc43008bdee585fe0e21fb96 --- /dev/null +++ b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + std::size_t length = N * H * W * G * C; + + std::vector strideDy = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector strideX = strideDy; + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + std::vector strideDGammaBeta = {C, 1}; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * length); + SimpleDeviceMem x_dev(sizeof(XDataType) * length); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * G * C); + SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * G * C); + + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool 
found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::size_t num_bytes = sizeof(DYDataType) * length + sizeof(XDataType) * length + + sizeof(GammaDataType) * G * C + sizeof(MeanInvStdDataType) * N * G * 2 + + sizeof(DGammaDataType) * G * C + sizeof(DBetaDataType) * G * C; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + {G, C}, + strideDGammaBeta, + strideDGammaBeta, + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + {G, C}, + strideDGammaBeta, + strideDGammaBeta, + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/18_groupnorm/groupnorm_swish.cpp b/client_example/18_groupnorm/groupnorm_swish.cpp deleted file mode 100644 index e1d198d2282b85f21fb4c04afde443133c801b15..0000000000000000000000000000000000000000 --- a/client_example/18_groupnorm/groupnorm_swish.cpp +++ /dev/null @@ -1,194 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp" - -using XDataType = ck::half_t; -using GammaDataType = float; -using BetaDataType = float; -using YDataType = ck::half_t; -using ComputeDataType = float; -using Swish = ck::tensor_operation::element_wise::Swish; - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - -int main(int argc, char* argv[]) -{ - ck::index_t N = 32; - ck::index_t H = 16; - ck::index_t W = 16; - ck::index_t G = 64; - ck::index_t C = 128; - - std::size_t xy_size = N * H * W * G * C; - std::size_t gamma_beta_size = G * C; - - std::vector xy_strides = {H * W * G * C, W * G * C, G * C, C, 1}; - std::vector gamma_beta_strides = {0, 0, 0, C, 1}; - - SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); - SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size); - SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size); - SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); - - using DeviceOp = ck::tensor_operation::device::DeviceNormalization; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - const auto& generic_op_ptr = op_ptrs[0]; - - auto generic_argument_ptr = - generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths - xy_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - xy_strides, // yStrides - {1, 2, 4}, // reduceDims - 1e-6, - x_device_buf.GetDeviceBuffer(), - gamma_device_buf.GetDeviceBuffer(), - beta_device_buf.GetDeviceBuffer(), - y_device_buf.GetDeviceBuffer(), - nullptr, - nullptr, - Swish{}); - - if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) - { - throw std::runtime_error( - "The generic kernel instance should be able to support any input shapes"); - }; - - std::string best_op_name; - bool found = false; - int best_op_id = -1; - float best_ave_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths - xy_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - xy_strides, // yStrides - {1, 2, 4}, // reduceDims - 1e-6, - x_device_buf.GetDeviceBuffer(), - gamma_device_buf.GetDeviceBuffer(), - beta_device_buf.GetDeviceBuffer(), - y_device_buf.GetDeviceBuffer(), - nullptr, - nullptr, - Swish{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t num_byte = - sizeof(XDataType) * 
xy_size + sizeof(GammaDataType) * gamma_beta_size + - sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size; - - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " - << op_name << std::endl; - - if(ave_time < best_ave_time) - { - found = true; - best_op_id = i; - best_op_name = op_name; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - } - else - { - std::cout << op_name << " does not support this problem" << std::endl; - } - } - - // run the best intance - if(found) - { - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " - << best_op_name << std::endl; - - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - - auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths - xy_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - xy_strides, // yStrides - {1, 2, 4}, // reduceDims - 1e-6, - x_device_buf.GetDeviceBuffer(), - gamma_device_buf.GetDeviceBuffer(), - beta_device_buf.GetDeviceBuffer(), - y_device_buf.GetDeviceBuffer(), - nullptr, - nullptr, - Swish{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } - - return 0; -} diff --git a/client_example/18_groupnorm/groupnorm_swish_fwd.cpp b/client_example/18_groupnorm/groupnorm_swish_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..26110193d7902286b37f39616a15528d153f1df6 --- /dev/null +++ b/client_example/18_groupnorm/groupnorm_swish_fwd.cpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
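The forward swish example that follows can optionally persist the per-group mean and inverse standard deviation; the SAVE_MEAN_INV_STD define switches between passing real buffers and passing nullptr for those two outputs. A minimal sketch of that compile-time toggle in isolation (plain host floats instead of device buffers):

#include <iostream>

// Mirrors the SAVE_MEAN_INV_STD switch in the example below: when the macro is
// defined, real buffers are passed for the saved mean / inverse std, otherwise
// nullptr is passed and the extra outputs are skipped.
#define SAVE_MEAN_INV_STD

int main()
{
    float save_mean_buf    = 0.f;
    float save_inv_std_buf = 0.f;

#ifdef SAVE_MEAN_INV_STD
    float* p_save_mean    = &save_mean_buf;
    float* p_save_inv_std = &save_inv_std_buf;
#else
    float* p_save_mean    = nullptr;
    float* p_save_inv_std = nullptr;
#endif

    std::cout << "save mean/inv-std enabled: " << (p_save_mean != nullptr) << '\n';
    (void)p_save_inv_std;
    return 0;
}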
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp" + +using XDataType = ck::half_t; +using GammaDataType = float; +using BetaDataType = float; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using Swish = ck::tensor_operation::element_wise::Swish; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + std::size_t xy_size = N * H * W * G * C; + std::size_t gamma_beta_size = G * C; + + std::vector xy_strides = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector gamma_beta_strides = {0, 0, 0, C, 1}; + std::vector save_mean_inv_std_strides = {G, 1}; + + SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); + SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size); + SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size); + SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); +#ifdef SAVE_MEAN_INV_STD + SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G); + SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G); +#endif + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto& generic_op_ptr = op_ptrs[0]; + + auto generic_argument_ptr = + generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + save_mean_inv_std_strides, // save_mean Strides + save_mean_inv_std_strides, // save_inv_std Strides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + Swish{}); + + if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) + { + throw std::runtime_error( + "The generic kernel instance should be able to support any input shapes"); + }; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + 
xy_strides, // yStrides + save_mean_inv_std_strides, // save_mean Strides + save_mean_inv_std_strides, // save_inv_std Strides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + Swish{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = + sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size + + sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size; + +#ifdef SAVE_MEAN_INV_STD + num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2; +#endif + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + save_mean_inv_std_strides, // save_mean Strides + save_mean_inv_std_strides, // save_inv_std Strides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + Swish{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool/CMakeLists.txt b/client_example/19_pool/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..861c1a32578e7d741c777032f7d22b63b14c098d --- /dev/null +++ b/client_example/19_pool/CMakeLists.txt @@ -0,0 +1,11 @@ +add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) +target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_other_operations) + +add_executable(client_max_pool2d_bwd max_pool2d_bwd.cpp) +target_link_libraries(client_max_pool2d_bwd PRIVATE 
composable_kernel::device_other_operations) + +add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp) +target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_other_operations) + +add_executable(client_avg_pool3d_bwd avg_pool3d_bwd.cpp) +target_link_libraries(client_avg_pool3d_bwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/19_pool/avg_pool3d_bwd.cpp b/client_example/19_pool/avg_pool3d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0bf4b9346ecbfdf913c7838249314d9e5c938fe1 --- /dev/null +++ b/client_example/19_pool/avg_pool3d_bwd.cpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/avg_pool3d_bwd.hpp" + +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; + +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{}, mMemSize_(mem_size) + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + void SetZero() const { (void)hipMemset(p_mem_, 0, mMemSize_); } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; + std::size_t mMemSize_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_d = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Zs = (Z - 1) * window_dilation_d + 1; + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCDHW + std::vector in_length = {N, C, Di, Hi, Wi}; + std::vector out_length = {N, C, Do, Ho, Wo}; + std::vector window_spatial_lengths = {Z, Y, X}; + std::vector window_strides = {window_stride_d, window_stride_h, window_stride_w}; + std::vector window_dilations{ + window_dilation_d, window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w}; + + std::size_t in_tensor_size = N * C * Di * Hi * Wi; + std::size_t out_tensor_size = N * C * Do * Ho * Wo; + + // tensor layout = NDHWC + std::vector in_tensor_stride = {Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}; + std::vector out_tensor_stride = {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}; + + SimpleDeviceMem 
dout_device_buf(sizeof(DOutDataType) * out_tensor_size); + SimpleDeviceMem din_device_buf(sizeof(DInDataType) * in_tensor_size); + + using DeviceOp = ck::tensor_operation::device:: + DeviceAvgPoolBwd<3, DOutDataType, DInDataType, DOutLayout, DInLayout>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_length, + in_length, + out_tensor_stride, + in_tensor_stride, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + din_device_buf.SetZero(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + in_tensor_size * sizeof(DInDataType) + out_tensor_size * sizeof(DOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_length, + in_length, + out_tensor_stride, + in_tensor_stride, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + din_device_buf.SetZero(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool_fwd/avg_pool3d_fwd.cpp b/client_example/19_pool/avg_pool3d_fwd.cpp similarity index 77% rename from client_example/19_pool_fwd/avg_pool3d_fwd.cpp rename to client_example/19_pool/avg_pool3d_fwd.cpp index db8e0569d717f08e9b64e4504d40dd23aab13129..846bd5ff4dfa93abd7d5f641bc991a18ad183bc3 100644 --- a/client_example/19_pool_fwd/avg_pool3d_fwd.cpp +++ b/client_example/19_pool/avg_pool3d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
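The pooling examples in this diff size their outputs with the usual dilation-aware formula: effective window = (window - 1) * dilation + 1, and output = (input + left pad + right pad - effective window) / stride + 1. A quick standalone check with the values from the avg_pool3d_bwd example above (Di = Hi = Wi = 30, 2x2x2 window, stride 2, dilation 1, padding 1 on each side), which gives Do = Ho = Wo = 16:

#include <iostream>

// Output-size arithmetic used by the pooling examples in this diff.
int pool_out_size(int in, int window, int stride, int dilation, int pad_l, int pad_r)
{
    const int eff = (window - 1) * dilation + 1; // effective (dilated) window
    return (in + pad_l + pad_r - eff) / stride + 1;
}

int main()
{
    // Di = Hi = Wi = 30, 2x2x2 window, stride 2, dilation 1, symmetric padding of 1.
    std::cout << pool_out_size(30, 2, 2, 1, 1, 1) << std::endl; // prints 16
    return 0;
}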
#include #include @@ -94,7 +94,6 @@ int main(int argc, char* argv[]) SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); - SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size); using DeviceOp = ck::tensor_operation::device::DevicePoolFwdMakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(out_indices_device_buf.GetDeviceBuffer()), - in_length, - window_spatial_lengths, - out_length, - in_tensor_stride, - out_tensor_stride, - out_tensor_stride, - window_strides, - window_dilations, - input_left_pads, - input_right_pads, - {2, 3, 4}); + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + nullptr, + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3, 4}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -184,21 +183,21 @@ int main(int argc, char* argv[]) std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(out_indices_device_buf.GetDeviceBuffer()), - in_length, - window_spatial_lengths, - out_length, - in_tensor_stride, - out_tensor_stride, - out_tensor_stride, - window_strides, - window_dilations, - input_left_pads, - input_right_pads, - {2, 3, 4}); + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + nullptr, + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3, 4}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/client_example/19_pool/max_pool2d_bwd.cpp b/client_example/19_pool/max_pool2d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a90889656d058ee4670b72603f6c64a143b038dd --- /dev/null +++ b/client_example/19_pool/max_pool2d_bwd.cpp @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
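The max_pool2d_bwd example that follows reuses the 3-D pooling instances for a 2-D problem by widening every descriptor with a dummy depth dimension: NCHW becomes NCDHW with D = 1, and the window gains a leading Z entry (strides, dilations and pads are widened the same way). A minimal sketch of that widening on the length vectors only, using plain std::vector and the example's shapes; it is not the TransformPool2dparamToPool3d helper itself:

#include <iostream>
#include <vector>

int main()
{
    // 2-D problem described in NCHW / (Y, X) order, as in the example below.
    std::vector<int> in_lengths = {2, 32, 30, 30}; // N, C, Hi, Wi
    std::vector<int> window     = {2, 2};          // Y, X

    // Widen to the 3-D form the pool3d instances expect: NCDHW with D = 1,
    // and a leading Z = 1 window entry.
    in_lengths.insert(in_lengths.begin() + 2, 1);  // N, C, 1, Hi, Wi
    window.insert(window.begin(), 1);              // 1, Y, X

    for(int v : in_lengths) std::cout << v << ' '; // 2 32 1 30 30
    std::cout << '\n';
    for(int v : window) std::cout << v << ' ';     // 1 2 2
    std::cout << '\n';
    return 0;
}

In the real helper the pooling dims also become {2, 3, 4}, so the reduction still runs over the (dummy) D, H and W axes.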
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" +#include "ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; +using IndexDataType = int32_t; + +// We use pool3d to implement pool2d in this example +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + +constexpr ck::index_t InOutRank = 5; +constexpr ck::index_t WindowRank = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +void TransformPool2dparamToPool3d(std::vector& input_lengths, + std::vector& window_lengths, + std::vector& output_lengths, + std::vector& input_stride, + std::vector& output_stride, + std::vector& indices_stride, + std::vector& window_strides, + std::vector& window_dilations, + std::vector& input_left_pads, + std::vector& input_right_pads, + std::vector& pooling_dims) +{ + // NCHW to NCDHW + input_lengths.insert(input_lengths.begin() + 2, 1); + output_lengths.insert(output_lengths.begin() + 2, 1); + input_stride.insert(input_stride.begin() + 2, 0); + output_stride.insert(output_stride.begin() + 2, 0); + indices_stride.insert(indices_stride.begin() + 2, 0); + + // YX to ZYX + window_lengths.insert(window_lengths.begin(), 1); + window_strides.insert(window_strides.begin(), 0); + window_dilations.insert(window_dilations.begin(), 0); + input_left_pads.insert(input_left_pads.begin(), 0); + input_right_pads.insert(input_right_pads.begin(), 0); + + pooling_dims = {2, 3, 4}; +} + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCHW + std::vector in_length = {N, C, Hi, Wi}; + std::vector out_length = {N, C, Ho, Wo}; + std::vector window_spatial_lengths = {Y, X}; + std::vector window_strides = {window_stride_h, window_stride_w}; + std::vector window_dilations = {window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_h, in_right_pad_w}; + std::vector pooling_dims = {2, 3}; + + std::size_t in_tensor_size = N * C * Hi * Wi; + std::size_t out_tensor_size = N * C * Ho * Wo; + + // tensor layout = NHWC + std::vector in_tensor_stride = {C * Hi * Wi, 1, 
Wi * C, C}; + std::vector out_tensor_stride = {C * Ho * Wo, 1, Wo * C, C}; + + TransformPool2dparamToPool3d(in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); + SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); + SimpleDeviceMem indices_device_buf(sizeof(IndexDataType) * out_tensor_size); + SimpleDeviceMem dout_device_buf(sizeof(DOutDataType) * out_tensor_size); + SimpleDeviceMem din_device_buf(sizeof(DInDataType) * in_tensor_size); + + // Generate index data from max pool forward + { + using MaxPoolFwdDeviceOp = + ck::tensor_operation::device::DevicePoolFwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + MaxPoolFwdDeviceOp>::GetInstances(); + + auto& op_ptr = op_ptrs[0]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + } + + // Run MaxPool bwd + using MaxPoolBwdDeviceOp = + ck::tensor_operation::device::DeviceMaxPoolBwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + MaxPoolBwdDeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_tensor_size, + in_tensor_size, + window_spatial_lengths, + window_strides, + window_dilations); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = in_tensor_size * sizeof(DInDataType) + + out_tensor_size * sizeof(IndexDataType) + + out_tensor_size * sizeof(DOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << "GB / s," + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name 
<< " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_tensor_size, + in_tensor_size, + window_spatial_lengths, + window_strides, + window_dilations); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool_fwd/max_pool2d_fwd.cpp b/client_example/19_pool/max_pool2d_fwd.cpp similarity index 99% rename from client_example/19_pool_fwd/max_pool2d_fwd.cpp rename to client_example/19_pool/max_pool2d_fwd.cpp index 84b818a60fdec6a62ffe027285bdbec854721d5a..99087b47d3a4aa5cabe91a70b4ae794a43e7db6e 100644 --- a/client_example/19_pool_fwd/max_pool2d_fwd.cpp +++ b/client_example/19_pool/max_pool2d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/19_pool_fwd/CMakeLists.txt b/client_example/19_pool_fwd/CMakeLists.txt deleted file mode 100644 index 13f9f73c83d55c801fdf4609e13f6b0813cb0c67..0000000000000000000000000000000000000000 --- a/client_example/19_pool_fwd/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) -target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_operations) - -add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp) -target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_operations) \ No newline at end of file diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt index a60bada473b38f041bc409ef42dc4fabb7ecc57d..383c5d663088a931c74ff9a16aa653736243691c 100644 --- a/client_example/20_splitk_gemm/CMakeLists.txt +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) -target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) + add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) + target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp index 94a57cd029a30d0fc73bae8cb8e43c1d475c95bc..5ace2e3056effb2f5a7acfb339e40e3b8590046d 100644 --- a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp +++ b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced 
Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -88,7 +88,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } @@ -191,6 +191,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/21_grouped_gemm_bias/CMakeLists.txt b/client_example/21_grouped_gemm_bias/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a09921e50a32d10e60e69ad658aada49b7dc829c --- /dev/null +++ b/client_example/21_grouped_gemm_bias/CMakeLists.txt @@ -0,0 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa08f49e7dea42e2677bfa72d817b0620be806be --- /dev/null +++ b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? 
Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, d0_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + d0_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + d0_dev_bufs.emplace_back(sizeof(D0DataType) * + f_matrix_space_size(Ms[i], Ns[i], 0, D0Layout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back( + {a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + std::array{d0_dev_bufs[i].GetDeviceBuffer()}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + std::array{0}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 2); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + 
for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1e1c39681ef03c24298a5ad4718c1e611703ac47 --- /dev/null +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -0,0 +1,13 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..92311b484a7e91114deda2708eade9be1f0c1fe0 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
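+// Overview (added for clarity): this example follows the same flow as the other fixed-NK
+// grouped GEMM clients -- build 16 group descriptors, copy the per-group kernel arguments
+// to the device, profile every DeviceGroupedGemmFixedNK instance, and report the fastest
+// one by TFLOPS. Here A is BF16 and B is int8, both row-major, the output E is BF16, and
+// k-batch is fixed to 1.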
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = I8; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int 
best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 1); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9dc5564fca742146cd475d7ee7bec97884e79c79 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
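+// Note: all-FP16, row-major variant of the fixed-NK grouped GEMM client. The profiling
+// flow is identical to the BF16/int8 example; the notable difference is that k-batch is
+// set to 32, i.e. the K (reduction) dimension of each GEMM is split across work-groups.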
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float 
best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 32); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3519e48aa61aaa3ed72d5a5f6ec4d37fe3013583 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
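+// Note: FP8 variant of the fixed-NK grouped GEMM client: A is FP16 and row-major, B is
+// FP8 and column-major, the output E is FP16, and k-batch is set to 16.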
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int 
best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 16); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d77f411a32d044c370170e69349e673c3fccb231 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
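+// Note: int8 variant of the fixed-NK grouped GEMM client: A is FP16, B is int8, both
+// row-major, the output E is FP16, and k-batch is set to 32.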
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using I8 = int8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = I8; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id 
= -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 32); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_im2col_col2im/CMakeLists.txt b/client_example/22_im2col_col2im/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d938d8cfb3e57a17361de5dd56925dcc4a09f5af --- /dev/null +++ b/client_example/22_im2col_col2im/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(client_image_to_column image_to_column.cpp) +target_link_libraries(client_image_to_column PRIVATE composable_kernel::device_other_operations) + +add_executable(client_column_to_image column_to_image.cpp) +target_link_libraries(client_column_to_image PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/22_im2col_col2im/column_to_image.cpp b/client_example/22_im2col_col2im/column_to_image.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ebe63198fbc7d943763289157dd84cd26cece3b --- /dev/null +++ b/client_example/22_im2col_col2im/column_to_image.cpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
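+// Overview (added for clarity): column-to-image (col2im) counterpart of image_to_column.cpp.
+// The input buffer holds the GEMM-shaped columns (G * N * Ho * Wo * Y * X * C elements) and
+// the output is the NHWGC image (N * Hi * Wi * G * C elements). Every available
+// DeviceConvTensorRearrange instance is timed, and the fastest one is then run once more
+// without timing. The 3x3 filter / 28x28 spatial size below is just an illustrative default.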
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; + +using ImageLayout = ck::tensor_layout::convolution::NHWGC; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 32; // batch size +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + + std::array in_spatial_lengths{Hi, Wi}; + std::array wei_spatial_lengths{Y, X}; + std::array out_spatial_lengths{Ho, Wo}; + + // We have NHWGC in memory space + // However, CK's API only accepts lengths and strides with order of GNCHW. + // Hence, we need to adjust the order of strides. + std::array image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array gemm_strides{Y * X * C, G * Y * X * C, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Ho * Wo * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Hi * Wi * G * C); + + using namespace ck::conv_tensor_rearrange_op; + + using DeviceOp = ck::tensor_operation::device::DeviceConvTensorRearrange; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(OutDataType) * G * N * Ho * Wo * Y * X * C; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(avg_time < best_avg_time) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = 
avg_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/22_im2col_col2im/image_to_column.cpp b/client_example/22_im2col_col2im/image_to_column.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0ceedd786261b691fd0eab23d1d78f5a56668bf0 --- /dev/null +++ b/client_example/22_im2col_col2im/image_to_column.cpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; + +using ImageLayout = ck::tensor_layout::convolution::NHWGC; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 32; // batch size +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + + std::array in_spatial_lengths{Hi, Wi}; + std::array wei_spatial_lengths{Y, X}; + std::array out_spatial_lengths{Ho, Wo}; + + // We have NHWGC in memory space + // However, CK's API only accepts lengths and strides with order of GNCHW. + // Hence, we need to adjust the order of strides. 
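+    // Worked illustration (added comment): in NHWGC memory the per-dimension strides are
+    // N: Hi*Wi*G*C, H: Wi*G*C, W: G*C, G: C and C: 1. Reordering them into the GNCHW order
+    // expected by the API gives {C, Hi*Wi*G*C, 1, Wi*G*C, G*C}, which is exactly the
+    // image_strides array below.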
+ std::array image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array gemm_strides{Y * X * C, G * Y * X * C, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * Y * X * C); + + using namespace ck::conv_tensor_rearrange_op; + + using DeviceOp = ck::tensor_operation::device::DeviceConvTensorRearrange; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(OutDataType) * G * N * Ho * Wo * Y * X * C; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(avg_time < best_avg_time) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/23_elementwise_transpose/CMakeLists.txt b/client_example/23_elementwise_transpose/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b2421d88136a9306a79d4b5dc9777670072354e --- /dev/null +++ b/client_example/23_elementwise_transpose/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_elementwise_transpose3d elementwise_transpose_3d.cpp) +target_link_libraries(client_elementwise_transpose3d 
PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..21602b19bda728520a8b7d57db5ff082d9c70f92 --- /dev/null +++ b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/transpose_3d.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + const int N = 16; + const int C = 8; + const int D = 8; + const int H = 8; + const int W = 8; + + std::vector ncdhw = {N, C, D, H, W}; + std::vector nchwd = {N, C, H, W, D}; + auto size = N * C * D * H * W; + + std::array ab_lengths{N, C, H, W, D}; + std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W + std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, C, H, W, D + + SimpleDeviceMem a_dev_buf(sizeof(ADataType) * size); + SimpleDeviceMem b_dev_buf(sizeof(BDataType) * size); + + std::array input = {a_dev_buf.GetDeviceBuffer()}; + std::array output = {b_dev_buf.GetDeviceBuffer()}; + + using DeviceElementwisePermuteInstance = ck::tensor_operation::device:: + DeviceElementwise, ck::Tuple, PassThrough, 5>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceElementwisePermuteInstance>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = + sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) + + sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]); + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + 
else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..dc55250bfe78ea11e48ad75fd0d522a54e66e9ae --- /dev/null +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -0,0 +1,102 @@ +if(GPU_TARGETS MATCHES "gfx9") +# Fwd scaleadd scaleadd relu +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_conv_operations) +# Fwd scaleadd AB +add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp32 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp16 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_bf16 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_int8 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_conv_operations) +# Fwd bilinear +add_executable(client_grouped_convnd_fwd_bilinear_residual_fp16 + grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp) 
+target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Fwd convinvscale +add_executable(client_conv3d_fwd_convinvscale_fp8 + grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations) +# Fwd convscale + Bias +add_executable(client_conv3d_fwd_convscale_add_fp8 + grouped_convnd_fwd_convscale_add/conv3d_fwd_convscale_add_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_add_fp8 PRIVATE composable_kernel::device_conv_operations) +# Fwd convscale + ReLU +add_executable(client_conv3d_fwd_convscale_relu_fp8 + grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_relu_fp8 PRIVATE composable_kernel::device_conv_operations) +# Fwd convscale + ReLU + AMAX +add_executable(client_conv3d_fwd_convscale_relu_amax_fp8 + grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_relu_amax_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_relu_amax_fp8 + PRIVATE composable_kernel::device_conv_operations + composable_kernel::device_other_operations + composable_kernel::device_reduction_operations + utility) +# Fwd convscale + AMAX +add_executable(client_conv3d_fwd_convscale_amax_fp8 + grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_amax_fp8 + PRIVATE composable_kernel::device_conv_operations + composable_kernel::device_other_operations + composable_kernel::device_reduction_operations + utility) +# Fwd convscale +add_executable(client_conv3d_fwd_convscale_fp8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_conv3d_fwd_convscale_bf8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_bf8 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_conv3d_fwd_convscale_fp8_bf8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_conv3d_fwd_convscale_bf8_fp8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) +# Bwd data bilinear +add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 + grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) +target_link_libraries(client_grouped_convnd_bwd_data_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Bwd weight bilinear +add_executable(client_grouped_convnd_bwd_weight_bilinear_residual_fp16 + grouped_convnd_bwd_weight_bilinear/grouped_conv_bwd_weight_bilinear_residual_fp16.cpp) +target_link_libraries(client_grouped_convnd_bwd_weight_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Fwd scale +add_executable(client_grouped_convnd_fwd_scale_fp16 + grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_kernel::device_conv_operations) +# Bwd data scale +add_executable(client_grouped_convnd_bwd_data_scale_fp16 + grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp) 
+target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations) +# Bwd weight scale +add_executable(client_grouped_convnd_bwd_weight_scale_fp16 + grouped_convnd_bwd_weight_scale/grouped_conv_bwd_weight_scale_fp16.cpp) +target_link_libraries(client_grouped_convnd_bwd_weight_scale_fp16 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb106e8d8eef32e9034d2681fc471b8d6794bde3 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_bwd_data_bilinear() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array 
input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple, + InDataType, + PassThrough, + PassThrough, + Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {in.GetDeviceBuffer()}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {in_lengths}, + {in_strides}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X + + 3 * G * N * Di * Hi * Wi * C; + std::size_t num_bytes = 2 * sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {in.GetDeviceBuffer()}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {in_lengths}, + {in_strides}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << 
std::endl; + } + return 0; +} + +int main() { return execute_conv_bwd_data_bilinear(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e53ecc6c99542ae7c44123ae98202789920c4c67 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_bwd_data_scale() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + 
OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X + + 3 * G * N * Di * Hi * Wi * C; + std::size_t num_bytes = 2 * sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_bwd_data_scale(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_weight_bilinear/grouped_conv_bwd_weight_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_weight_bilinear/grouped_conv_bwd_weight_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5993ddf32bdda46c29e9535fdc05b23adf0d317 --- 
/dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_weight_bilinear/grouped_conv_bwd_weight_bilinear_residual_fp16.cpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 32; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_bwd_weight_bilinear() +{ + constexpr ck::index_t split_k = 2; + + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD, + InDataType, + WeiDataType, + OutDataType, + ck::Tuple, + PassThrough, + Bilinear, + PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device 
operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in.GetDeviceBuffer()), + static_cast(wei.GetDeviceBuffer()), + static_cast(out.GetDeviceBuffer()), + {wei.GetDeviceBuffer()}, + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + {wei_lengths}, + {wei_strides}, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + Bilinear{2.f, 2.f}, + PassThrough{}, + split_k); + + SimpleDeviceMem workspace_buf(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_buf.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X + 3 * G * K * Z * Y * X * C; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + 2 * sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in.GetDeviceBuffer()), + static_cast(wei.GetDeviceBuffer()), + static_cast(out.GetDeviceBuffer()), + {wei.GetDeviceBuffer()}, + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + {wei_lengths}, + {wei_strides}, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + Bilinear{2.f, 2.f}, + PassThrough{}, + split_k); + + SimpleDeviceMem workspace_buf(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_buf.GetDeviceBuffer()); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_bwd_weight_bilinear(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_weight_scale/grouped_conv_bwd_weight_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_weight_scale/grouped_conv_bwd_weight_scale_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c68e8b7602332520797dcc6d6fec5adbe1d74577 --- 
/dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_weight_scale/grouped_conv_bwd_weight_scale_fp16.cpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 32; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_bwd_weight_scale() +{ + constexpr ck::index_t split_k = 2; + + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD, + InDataType, + WeiDataType, + OutDataType, + ck::Tuple<>, + PassThrough, + Scale, + PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout 
<< "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in.GetDeviceBuffer()), + static_cast(wei.GetDeviceBuffer()), + static_cast(out.GetDeviceBuffer()), + {}, + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + {}, + {}, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + Scale{2.f}, + PassThrough{}, + split_k); + + SimpleDeviceMem workspace_buf(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_buf.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X + G * K * Z * Y * X * C; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in.GetDeviceBuffer()), + static_cast(wei.GetDeviceBuffer()), + static_cast(out.GetDeviceBuffer()), + {}, + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + {}, + {}, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + Scale{2.f}, + PassThrough{}, + split_k); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + SimpleDeviceMem workspace_buf(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_buf.GetDeviceBuffer()); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_bwd_weight_scale(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32ab481319449d473f933b1862e79c7f25cdeda1 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp @@ -0,0 +1,221 
@@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_bilinear() +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. 
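+    // Note (illustrative aside, not part of the original example): for a packed
+    // NDHWGC tensor the fastest-varying dimension is C (stride 1), followed by
+    // G, W, H, D and finally N. Reordering those packed strides into the GNCDHW
+    // order expected by the API reproduces the values hard-coded in in_strides
+    // below:
+    //
+    //   GNCDHW dim:   G    N                C    D            H         W
+    //   stride:       C    Di*Hi*Wi*G*C     1    Hi*Wi*G*C    Wi*G*C    G*C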
+ std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {out.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths}, + {out_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 3 * N * Ho * Wo * G * K; + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * 2 * N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without 
timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {out.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths}, + {out_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_fwd_bilinear(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7059e24d8e1d69524c5be27c5f427a9f93c82049 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + 
+template +bool run_grouped_conv_fwd_convinvscale( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvInvscale, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + 
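+    // Note (illustrative aside, not part of the original example): traced for
+    // the 3-D case, the two std::rotate calls earlier in this function reorder
+    // each length/stride array in two steps, e.g. for in_lengths:
+    //
+    //   {N, Di, Hi, Wi, G, C}      // NDHWGC, as passed in
+    //   -> {G, N, Di, Hi, Wi, C}   // GNDHWC, after the first rotate
+    //   -> {G, N, C, Di, Hi, Wi}   // GNCDHW, after the second rotate
+    //
+    // The stride arrays are permuted identically, so lengths and strides stay
+    // paired per dimension.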
+ // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvInvscale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvInvscale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..775ea99ecd520304fcf90d242e96d3cec330430a --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
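+// Note (illustrative aside, not part of the original example): the shape
+// constants below are chosen so that, with a 3x3x3 filter, unit stride, unit
+// dilation and a padding of 1 on each side (the values filled in by
+// common.hpp), the output spatial sizes equal the input ones:
+//
+//   out = (in + pad_left + pad_right - dilation * (filter - 1) - 1) / stride + 1
+//       = (28 + 1 + 1 - 2 - 1) / 1 + 1 = 28,  and  (3 + 1 + 1 - 2 - 1) / 1 + 1 = 3
+//
+// which matches Do = Di = Ho = Hi = 28 and Wo = Wi = 3.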
+ +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convinvscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..51eec5b1ab14dd3e66e026fdedcb8a9b41a58b3d --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * 
) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd_convscale( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvScale, + 
AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f901d08ab6eaf784dfd86e337cffe76725798c04 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
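+// Note (illustrative aside, not part of the original example): the four
+// convscale translation units in this directory appear to differ only in the
+// element types they plug into run_grouped_conv_fwd_convscale from common.hpp;
+// the problem shape and layouts are identical:
+//
+//   file                               In    Wei   ACompute  BCompute  Out
+//   conv3d_fwd_convscale_fp8.cpp       f8    f8    f8        f8        f8
+//   conv3d_fwd_convscale_bf8.cpp       bf8   bf8   bf8       bf8       f8
+//   conv3d_fwd_convscale_bf8_fp8.cpp   bf8   f8    bf8       f8        f8
+//   conv3d_fwd_convscale_fp8_bf8.cpp   f8    bf8   f8        bf8       f8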
+ +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = InDataType; +using BComputeDataType = AComputeDataType; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..192c4fdcb904a4b65ba4e22fdbc7caa1dd1fb35f --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::bf8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15d063c2f139d35c6824557cb615099c93513fd6 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b38225f2b916b3cd347d95190a2581b14d494684 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::bf8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_add/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_add/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4bba13693c55a9987fcf9ff32bb8d61e12b50479 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_add/common.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
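+// Note (illustrative aside, not part of the original header): compared with the
+// plain convscale example, this variant appears to add a single D tensor, a
+// float bias (BiasDataType) with the same GNKDHW lengths and strides as the
+// output, consumed by the ConvScaleAdd element-wise op. The flop estimate
+// therefore uses ds_size = 3 + 1 below, i.e. three scale multipliers plus the
+// bias addition.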
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/type.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_add.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleAdd = ck::tensor_operation::element_wise::ConvScaleAdd; +using F32 = float; +using BiasDataType = F32; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd_convscale_add( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + + namespace ctc = ck::tensor_layout::convolution; + static_assert(NumDimSpatial == 3 && ck::is_same_v && + ck::is_same_v && + ck::is_same_v, + "Unsupported configuration"); + + const ck::index_t G = in_lengths[4]; + const ck::index_t N = in_lengths[0]; + const ck::index_t K = wei_lengths[1]; + const ck::index_t C = in_lengths[5]; + const ck::index_t Z = wei_lengths[2]; + const ck::index_t Y = wei_lengths[3]; + const ck::index_t X = wei_lengths[4]; + const ck::index_t Di = in_lengths[1]; + const ck::index_t Hi = in_lengths[2]; + const ck::index_t Wi = in_lengths[3]; + const ck::index_t Do = out_lengths[1]; + const ck::index_t Ho = out_lengths[2]; + const ck::index_t Wo = out_lengths[3]; + + const std::size_t in_mem_size = sizeof(InDataType) * N * Di * Hi * Wi * G * C; + const std::size_t wei_mem_size = sizeof(WeiDataType) * G * K * Z * Y * X * C; + const std::size_t out_mem_size = sizeof(OutDataType) * N * Do * Ho * Wo * G * K; + const std::size_t bias_mem_size = sizeof(BiasDataType) * N * Do * Ho * Wo * G * K; + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + SimpleDeviceMem bias(bias_mem_size); + + float scale_in = float(std::rand()) / 
float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + // We have NDHWGC/GKZYXC/NDHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. + const std::array input_lengths{G, N, C, Di, Hi, Wi}; + const std::array input_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + const std::array weights_lengths{G, K, C, Z, Y, X}; + const std::array weights_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + const std::array output_lengths{G, N, K, Do, Ho, Wo}; + const std::array output_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + const std::array bias_lengths{G, N, K, Do, Ho, Wo}; + const std::array bias_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + const std::array conv_filter_strides{1, 1, 1}; + const std::array conv_filter_dilations{1, 1, 1}; + const std::array input_left_pads{1, 1, 1}; + const std::array input_right_pads{1, 1, 1}; + + std::size_t ds_size = 3 + 1; // 3 element-wise scale multipliers + 1 elementwise Bias + std::size_t flop = GetFlops(output_lengths, weights_lengths, ds_size); + std::size_t num_bytes = in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + + sizeof(float) + out_mem_size + bias_mem_size; + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + ConvScaleAdd, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + input_lengths, + input_strides, + weights_lengths, + weights_strides, + std::array, 1>{ + {bias_lengths}}, + std::array, 1>{ + {bias_strides}}, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScaleAdd{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) 
+ { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + input_lengths, + input_strides, + weights_lengths, + weights_strides, + std::array, 1>{ + {bias_lengths}}, + std::array, 1>{ + {bias_strides}}, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScaleAdd{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_add/conv3d_fwd_convscale_add_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_add/conv3d_fwd_convscale_add_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5324bb7144dba737a8325734760ba3b980602a47 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_add/conv3d_fwd_convscale_add_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale_add( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c78cacf266c24b8cfe66daf9cbc2350ac7af287f --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp @@ -0,0 +1,834 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
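// ---------------------------------------------------------------------------
// Illustrative overview (a sketch added for clarity, not taken verbatim from
// the patch): the helpers declared in this header chain three device
// operations for the FP8 "convscale + reduce" client examples, roughly
//
//   acc  = conv(x, w)                    // accumulated as ConvOutDataType
//   acc  = scale_in * scale_wei * acc    // ConvScale / ConvScaleRelu CDE op
//   y    = cast_to_f8(scale_out * acc)   // DeviceElementwise scale + convert
//   amax = max(|acc|)                    // two-stage DeviceReduce with AMAX
//
// The exact composition of the scales inside ScaleScalePass / ScaleScaleRelu
// is inferred from the operator names; the authoritative definitions live in
// ck/tensor_operation/gpu/element/element_wise_operation.hpp.
// ---------------------------------------------------------------------------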
+ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/type.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/reduce.hpp" +#include "ck/library/utility/host_tensor.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleRelu = ck::tensor_operation::element_wise::ScaleScaleRelu; +using ConvScale = ck::tensor_operation::element_wise::ScaleScalePass; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // 2 * G * N * K * C * * + + // + ds_size * => + // => * ( 2 * C * + ds_size) => + // => G * N * K * * (2 * C * + + // ds_size) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * K * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (ds_size + static_cast(2) * C * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>())); +} + +template +std::size_t GetTensorSize(const std::array& lengths) +{ + + return std::accumulate(std::begin(lengths), + std::end(lengths), + static_cast(1), + std::multiplies()); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * GetTensorSize(input_lengths); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * GetTensorSize(weights_lengths); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * GetTensorSize(output_lengths); +} + +template +bool ConvolutionScale(SimpleDeviceMem& in, + SimpleDeviceMem& wei, + SimpleDeviceMem& out, + ConvElementOp elementwise_op, + const std::array& in_lengths, + const std::array& in_strides, + const std::array& wei_lengths, + const std::array& wei_strides, + const std::array& out_lengths, + const std::array& out_strides); + +template +bool TensorScaleConvert(SimpleDeviceMem& in, + SimpleDeviceMem& out, + float scale_out, + const std::array& lengths, + const std::array& strides); + +template +bool TensorFullReduction(SimpleDeviceMem& tensor, + SimpleDeviceMem& out_amax, + const std::array& lengths, + const std::array& strides); + +template +bool 
run_grouped_conv_fwd_convscale_reduce( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + + namespace ctc = ck::tensor_layout::convolution; + static_assert(NumDimSpatial == 3 && ck::is_same_v && + ck::is_same_v && + ck::is_same_v, + "Unsupported configuration"); + + const ck::index_t G = in_lengths[4]; + const ck::index_t N = in_lengths[0]; + const ck::index_t K = wei_lengths[1]; + const ck::index_t C = in_lengths[5]; + const ck::index_t Z = wei_lengths[2]; + const ck::index_t Y = wei_lengths[3]; + const ck::index_t X = wei_lengths[4]; + const ck::index_t Di = in_lengths[1]; + const ck::index_t Hi = in_lengths[2]; + const ck::index_t Wi = in_lengths[3]; + const ck::index_t Do = out_lengths[1]; + const ck::index_t Ho = out_lengths[2]; + const ck::index_t Wo = out_lengths[3]; + + const std::size_t in_mem_size = sizeof(InDataType) * N * Di * Hi * Wi * G * C; + const std::size_t wei_mem_size = sizeof(WeiDataType) * G * K * Z * Y * X * C; + const std::size_t conv_out_mem_size = sizeof(ConvOutDataType) * N * Do * Ho * Wo * G * K; + const std::size_t out_mem_size = sizeof(OutDataType) * N * Do * Ho * Wo * G * K; + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem conv_out(conv_out_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + // We have NDHWGC/GKZYXC/NDHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. + const std::array input_lengths{G, N, C, Di, Hi, Wi}; + const std::array input_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + const std::array weights_lengths{G, K, C, Z, Y, X}; + const std::array weights_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + const std::array output_lengths{G, N, K, Do, Ho, Wo}; + const std::array output_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + /* + * FP8 Convolution with Scaling + */ + std::cout << "\n\nConvolution with scale Benchmarking:" << std::endl; + auto elementwise_op = ConvElementOp{ck::tensor_operation::element_wise::Scale{scale_in}, + ck::tensor_operation::element_wise::Scale{scale_wei}, + {}}; + auto conv_ok = ConvolutionScale(in, + wei, + conv_out, + elementwise_op, + input_lengths, + input_strides, + weights_lengths, + weights_strides, + output_lengths, + output_strides); + + if(!conv_ok) + return false; + + /* + * Scale with output weight and convert to FP8 + */ + std::cout << "\n\nElement-wise scale + convert Benchmarking:" << std::endl; + auto elem_wise_ok = TensorScaleConvert( + conv_out, out, scale_out, output_lengths, output_strides); + + if(!elem_wise_ok) + return false; + + /* + * Compute AMAX + */ + std::cout << "\n\nAMAX Benchmarking:" << std::endl; + SimpleDeviceMem amax_device(sizeof(ConvOutDataType)); + auto reduction_ok = + TensorFullReduction(conv_out, amax_device, output_lengths, output_strides); + + if(!reduction_ok) + return false; + + return true; +} + +template +bool ConvolutionScale(SimpleDeviceMem& in, + SimpleDeviceMem& wei, + SimpleDeviceMem& out, + ConvElementOp elementwise_op, + const std::array& in_lengths, + const std::array& in_strides, + const std::array& wei_lengths, + const std::array& wei_strides, + const std::array& out_lengths, + const std::array& 
out_strides) +{ + + const std::array conv_filter_strides{1, 1, 1}; + const std::array conv_filter_dilations{1, 1, 1}; + const std::array input_left_pads{1, 1, 1}; + const std::array input_right_pads{1, 1, 1}; + + const auto in_mem_size = GetInputByte(in_lengths); + const auto wei_mem_size = GetWeightByte(wei_lengths); + const auto out_mem_size = GetOutputByte(out_lengths); + + std::size_t ds_size = 2; // 2 element-wise scale multipliers + if constexpr(ck::is_same_v) + { + ds_size += 1; // +1 element-wise relu + } + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + out_mem_size; + + using ConvDeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvElementOp, + AComputeType, + BComputeType>; + // get device op instances + const auto conv_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + ConvDeviceOp>::GetInstances(); + + std::cout << "found " << conv_ptrs.size() << " instances" << std::endl; + + std::string conv_best_op_name; + int conv_best_op_id = -1; + float conv_best_avg_time = std::numeric_limits::max(); + float conv_best_gb_per_sec = 0; + float conv_best_tflops = 0; + + // profile device operation instances + std::cout << "Run all convolution instances and do timing" << std::endl; + + for(int i = 0; i < conv_ptrs.size(); ++i) + { + auto& op_ptr = conv_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + elementwise_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > conv_best_tflops) + { + conv_best_op_id = i; + conv_best_op_name = op_name; + conv_best_avg_time = avg_time; + conv_best_gb_per_sec = gb_per_sec; + conv_best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(conv_best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << conv_best_avg_time << " ms, " << conv_best_tflops + << " TFlops, " << conv_best_gb_per_sec << " GB/s, " << conv_best_op_name << std::endl; + + // run the best instance + { + auto& op_ptr = conv_ptrs[conv_best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + 
input_right_pads, + PassThrough{}, + PassThrough{}, + elementwise_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return true; +} + +template +bool TensorScaleConvert(SimpleDeviceMem& in, + SimpleDeviceMem& out, + float scale_out, + const std::array& lengths, + const std::array& strides) +{ + + const auto tensor_size = GetTensorSize(lengths); + + const std::size_t in_mem_size = sizeof(InDataType) * tensor_size; + const std::size_t out_mem_size = sizeof(OutDataType) * tensor_size; + + std::size_t flop = 2 * tensor_size; // element-wise scale + convert + + std::size_t bytes = + in_mem_size + sizeof(float) + out_mem_size; // read from in, scale, write to out + + using DeviceScaleConvert = + ck::tensor_operation::device::DeviceElementwise, + ck::Tuple, + ck::tensor_operation::element_wise::Scale, + NumDimSpatial + NumNonSpatialDim>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceScaleConvert>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all DeviceScaleConvert instances and do timing" << std::endl; + + auto scale_convert = ck::tensor_operation::element_wise::Scale{scale_out}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(lengths, + {strides}, + {strides}, + {in.GetDeviceBuffer()}, + {out.GetDeviceBuffer()}, + scale_convert); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance found." 
<< std::endl; + return false; + } + else + { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(lengths, + {strides}, + {strides}, + {in.GetDeviceBuffer()}, + {out.GetDeviceBuffer()}, + scale_convert); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return true; +} + +template +bool TensorFullReduction(SimpleDeviceMem& tensor, + SimpleDeviceMem& out_amax, + const std::array& lengths, + const std::array& strides) +{ + const auto spatial_dim_size = std::accumulate(std::next(std::begin(lengths), NumNonSpatialDim), + std::end(lengths), + static_cast(1), + std::multiplies<>()); + const auto tensor_size = GetTensorSize(lengths); + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + // Get the reduction operation + using ReduceOperation = typename ck::reduce_binary_operator::opType; + using InElementwiseOperation = + typename ck::reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename ck::reduce_unary_operator::AccElementwiseOperation; + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + std::tie(in_elementwise_op, acc_elementwise_op) = + ck::reduce_unary_operator::GetElementwiseOperator( + static_cast(tensor_size)); + + std::array reduce_out_lengths{1}; + std::array reduce_out_strides{1}; + + SimpleDeviceMem partial_reduce_tensor(sizeof(OutDataType) * spatial_dim_size); + std::array reduce_part_lengths; + std::copy(std::next(std::begin(lengths), NumNonSpatialDim), + std::end(lengths), + std::begin(reduce_part_lengths)); + std::array reduce_part_strides; + copy(HostTensorDescriptor(reduce_part_lengths).GetStrides(), reduce_part_strides); + + { + std::cout << "\nReduction of nonspatial dimensions:" << std::endl; + using DeviceOp = + ck::tensor_operation::device::DeviceReduce; // OutputIndex + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + std::array reduce_dims; + std::iota(reduce_dims.begin(), reduce_dims.end(), 0); // 0,..., NumNonSpatialDim-1 + + ck::index_t num_in_elements = tensor_size; + ck::index_t num_out_elements = spatial_dim_size; + + // profile device operation instances + std::cout << "Run partial reduction and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(lengths, + strides, + reduce_part_lengths, + reduce_part_strides, + reduce_dims, + 1.0, + 0.0, + tensor.GetDeviceBuffer(), + nullptr, + partial_reduce_tensor.GetDeviceBuffer(), + nullptr, + in_elementwise_op, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = 
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + std::size_t num_bytes = + num_in_elements * sizeof(InDataType) + num_out_elements * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(ave_time < best_ave_time) + { + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance found." << std::endl; + return false; + } + else + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best instance + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(lengths, + strides, + reduce_part_lengths, + reduce_part_strides, + reduce_dims, + 1.0, + 0.0, + tensor.GetDeviceBuffer(), + nullptr, + partial_reduce_tensor.GetDeviceBuffer(), + nullptr, + in_elementwise_op, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + } + + { + std::cout << "\nReduction of spatial dimensions:" << std::endl; + using DeviceOp = ck::tensor_operation::device::DeviceReduce; // OutputIndex + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + std::array reduce_dims; + std::iota(reduce_dims.begin(), reduce_dims.end(), 0); // 0,..., NumDimSpatial-1 + + ck::index_t num_in_elements = spatial_dim_size; + ck::index_t num_out_elements = 1; + + // profile device operation instances + std::cout << "Run final reduction and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(reduce_part_lengths, + reduce_part_strides, + reduce_out_lengths, + reduce_out_strides, + reduce_dims, + 1.0, + 0.0, + partial_reduce_tensor.GetDeviceBuffer(), + nullptr, + out_amax.GetDeviceBuffer(), + nullptr, + PassThrough{}, + acc_elementwise_op); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + num_in_elements * sizeof(OutDataType) + num_out_elements * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(ave_time < best_ave_time) + { + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance found." 
<< std::endl; + return false; + } + else + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best instance + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(reduce_part_lengths, + reduce_part_strides, + reduce_out_lengths, + reduce_out_strides, + reduce_dims, + 1.0, + 0.0, + partial_reduce_tensor.GetDeviceBuffer(), + nullptr, + out_amax.GetDeviceBuffer(), + nullptr, + PassThrough{}, + acc_elementwise_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + } + + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1c0299b84134ab28c7b406c753e468a0abf42363 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using ConvOutDataType = float; // data type of convolution result +using OutDataType = ck::f8_t; // data type of final result +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using ConvElementOp = ConvScale; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +constexpr auto ReduceOpId = ck::ReduceTensorOp::AMAX; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale_reduce( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_relu_amax_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_relu_amax_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..182642c030db2a5cb679a94a44ded65f51168e8d --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_relu_amax_fp8.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
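// Illustrative host-side cross-check for the AMAX result produced by
// TensorFullReduction in common.hpp (a minimal sketch, not part of the patch;
// it assumes the convolution output buffer holds plain float values, which
// matches ConvOutDataType = float in these examples, and the helper name
// host_amax_reference is hypothetical).
#include <hip/hip_runtime.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

inline float host_amax_reference(const void* device_ptr, std::size_t num_elements)
{
    std::vector<float> host(num_elements);
    // Copy the device tensor back to the host; error handling is omitted in this sketch.
    (void)hipMemcpy(
        host.data(), device_ptr, num_elements * sizeof(float), hipMemcpyDeviceToHost);

    float amax = 0.f;
    for(const float v : host)
    {
        amax = std::max(amax, std::abs(v));
    }
    return amax; // compare against the single value the device reduction writes to out_amax
}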
+ +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using ConvOutDataType = float; // data type of convolution result +using OutDataType = ck::f8_t; // data type of final result +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using ConvElementOp = ConvScaleRelu; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +constexpr auto ReduceOpId = ck::ReduceTensorOp::AMAX; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale_reduce( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ee188429b4cb4750dc7776bd199d56e005788a1d --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd_convscale_relu( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to 
GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t ds_size = 3 + 1; // 3 element-wise scale multipliers + 1 elementwise Relu + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvScaleRelu, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScaleRelu{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + 
best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScaleRelu{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4003dc7c86c65369b9bc09225ca227d81a7ffa80 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale_relu( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..11e69f5bb29f61cfc684b58be7ac169a2cacc1ec --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_scale() +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. 
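    // Illustrative derivation of the reordered strides (added for clarity, not
    // part of the patch): for a densely packed NDHWGC buffer the physical
    // strides of (N, D, H, W, G, C) are
    //   { Di*Hi*Wi*G*C, Hi*Wi*G*C, Wi*G*C, G*C, C, 1 }.
    // Permuting those six values into the G, N, C, D, H, W order that CK
    // expects reproduces the in_strides initializer used just below:
    //   { C, Di*Hi*Wi*G*C, 1, Hi*Wi*G*C, Wi*G*C, G*C }.
    // The same permutation applied to packed GKZYXC and NDHWGK buffers yields
    // wei_strides and out_strides, respectively.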
+ std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 3 * N * Ho * Wo * G * K; + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * 2 * N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + 
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_fwd_scale(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc new file mode 100644 index 0000000000000000000000000000000000000000..3f6f7b07737cecc4122c9dea62c87b19aafd5400 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_scaleadd_ab() +{ + constexpr ck::index_t NumAs = 2; + constexpr ck::index_t NumBs = 2; + + constexpr float scale = 1.5f; + + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. 
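    // Note on the performance arithmetic used in the timing loop further below
    // (added for clarity; the constants are unchanged from the surrounding
    // code): invoker_ptr->Run(..., StreamConfig{nullptr, true}) reports the
    // average kernel time in milliseconds, so the milli prefix is folded into
    // the divisors:
    //   TFLOP/s = flop  / (time_ms * 1e-3) / 1e12 = flop  / 1.E9 / time_ms
    //   GB/s    = bytes / (time_ms * 1e-3) / 1e9  = bytes / 1.E6 / time_ms
    // which is exactly what the "tflops" and "gb_per_sec" expressions compute.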
+ std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + using InputDtype = ck::tuple_element_t<0, InDataType>; + using InputBiasDtype = ck::tuple_element_t<1, InDataType>; + using WeightDtype = ck::tuple_element_t<0, WeiDataType>; + using WeightBiasDtype = ck::tuple_element_t<1, WeiDataType>; + + SimpleDeviceMem in(sizeof(InputDtype) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem in_bias(sizeof(InputBiasDtype) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeightDtype) * G * K * Z * Y * X * C); + SimpleDeviceMem wei_bias(sizeof(WeightBiasDtype) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + ScaleAdd, + ScaleAdd, + PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::array as = {in.GetDeviceBuffer(), in_bias.GetDeviceBuffer()}; + std::array bs = {wei.GetDeviceBuffer(), wei_bias.GetDeviceBuffer()}; + std::array ds{}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(as, + bs, + ds, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + ScaleAdd{scale}, + ScaleAdd{scale}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Z * Y * X + + N * Di * Hi * Wi * G * C + G * K * Z * Y * X * C; + std::size_t num_bytes = 2 * sizeof(InDataType) * N * Di * Hi * Wi * G * C + + 2 * sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * N * Do * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable 
instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(as, + bs, + ds, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + ScaleAdd{scale}, + ScaleAdd{scale}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fef3f7428c2149cd11892a987c320eeba93135cc --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = ck::bhalf_t; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..43db279191331edf969bdc15e526f4694c1650ec --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = ck::half_t; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cccec47701fbf52b9bc9ab2f37c49e29387224f9 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = float; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..28674c8abeb702211b699ca5d1c864837bb54990 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = int8_t; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc new file mode 100644 index 0000000000000000000000000000000000000000..4e3cf69637c55178c4a7007891bd918c1cacf5df --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_scaleadd_scaleadd_relu() +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. 
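The layout remark above is easier to follow with concrete numbers. The standalone sketch below (plain C++; the G/N/C/Di/Hi/Wi values mirror the constants defined earlier in this file) prints the strides of a packed NDHWGC input reported in the GNCDHW order the API expects, which is exactly the `in_strides` array built a few lines further down.

```cpp
#include <array>
#include <cstdio>

int main()
{
    // Packed NDHWGC memory: stepping one unit along each logical axis jumps by
    //   N -> Di*Hi*Wi*G*C, D -> Hi*Wi*G*C, H -> Wi*G*C, W -> G*C, G -> C, C -> 1.
    constexpr long G = 32, N = 64, C = 32, Di = 14, Hi = 14, Wi = 14;
    // The same jumps, listed in GNCDHW order as the convolution API expects.
    const std::array<long, 6> in_strides{
        C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
    for(const long s : in_strides)
        std::printf("%ld ", s);
    std::printf("\n");
    return 0;
}
```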
+ std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + SimpleDeviceMem d0(sizeof(std::tuple_element_t<0, DDataTypes>) * N * Do * Ho * Wo * G * K); + SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + NumDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, std::tuple_element_t<1, DDataTypes>>, + OutDataType, + PassThrough, + PassThrough, + ScaleAddScaleAddRelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {d0.GetDeviceBuffer(), d1.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths, bias_lengths}, + {out_strides, bias_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ScaleAddScaleAddRelu{2.f, 2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 2 * N * Ho * Wo * G * K; + std::size_t num_bytes = + sizeof(InDataType) * N * Hi * Wi * G * C + sizeof(WeiDataType) * G * K * Y * X * C + + (sizeof(OutDataType) + sizeof(std::tuple_element_t<0, DDataTypes>) + + sizeof(std::tuple_element_t<1, DDataTypes>)) * + N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support 
this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {d0.GetDeviceBuffer(), d1.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths, bias_lengths}, + {out_strides, bias_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ScaleAddScaleAddRelu{2.f, 2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a32c4f74277250de93f2c06d6419ccc39b13514 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::bhalf_t; +using WeiDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc" + +int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3e91072b34104c85aab41363e06d34739dd4831 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. 
+using DDataTypes = std::tuple; + +#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc" + +int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7ed96b6a04a67b85e366e2a1405561ae65df925 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc" + +int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9959664d2a110483f26340e25224cee16997c5e9 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using OutDataType = int8_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. 
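The `DDataTypes` alias declared next is a `std::tuple` of the two D-tensor element types, and the shared `.inc` sizes the `d0`/`d1` device buffers with `std::tuple_element_t<I, DDataTypes>`. A tiny standalone illustration of that mechanism follows; the `float, float` pair is only an example and not the types used by this file.

```cpp
#include <cstdio>
#include <tuple>

// Illustrative stand-in for the per-file alias; each element is one D tensor's type.
using DDataTypes = std::tuple<float, float>;

int main()
{
    // std::tuple_element_t<I, Tuple> extracts the I-th type, which the .inc uses
    // to compute the per-tensor buffer sizes.
    std::printf("%zu %zu\n",
                sizeof(std::tuple_element_t<0, DDataTypes>),
                sizeof(std::tuple_element_t<1, DDataTypes>));
    return 0;
}
```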
+using DDataTypes = std::tuple; + +#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc" + +int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); } diff --git a/client_example/25_wrapper/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b1e9d20bfde89063c3b55c86dff2c6517343f702 --- /dev/null +++ b/client_example/25_wrapper/CMakeLists.txt @@ -0,0 +1,10 @@ +add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) +target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) +add_executable(client_wrapper_img2col wrapper_img2col.cpp) +target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp) + target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations) + add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp) + target_link_libraries(client_wrapper_optimized_gemm PRIVATE composable_kernel::device_other_operations) +endif() diff --git a/client_example/25_wrapper/README.md b/client_example/25_wrapper/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eba3de017f41bf7cc7de8a4f2cd581dfd81c0093 --- /dev/null +++ b/client_example/25_wrapper/README.md @@ -0,0 +1,177 @@ +# Composable Kernel wrapper GEMM tutorial + +This tutorial demonstrates how to implement matrix multiplication using Composable Kernel (CK) +wrapper. We present the base version of GEMM without most of the available optimizations; however, +it's worth noting that CK has kernels with different optimizations. + +To implement these optimizations, you can use the CK wrapper or directly use available instances in +CK. You can also refer to the +[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp), +that uses CK wrapper based on the +[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation. + +The kernel definition should look similar to: + +```cpp +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +``` + +We pass pointers to global memory and matrix dimensions via arguments. Additionally, we pass +selected lengths of processed data through each block (`tile_shape`) and thread layout +(`thread_layout`). For compilation time parameters, we define the data type, +[traits for the GEMM operation](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp) +and scalar per vector value during copy. + +Step 1: Create layouts for global and LDS memory. + +```cpp + // Specify layouts for global memory. + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + + // Specify layouts for tiles. 
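    // Note: the tile layouts defined next are fully compile-time (constexpr shapes and
    // strides built from ck::Number constants), while the global layouts above carry
    // the runtime M, N and K kernel arguments.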
+ constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto c_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); + + // Apply padding for global memory. + auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout)); + auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout)); + auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout)); +``` + +We pad layouts for global tensors in case M, N, and K are not divisible by `MPerBlock`, `NPerBlock`, or +`KPerBlock`. + +Step 2: Create tensors for global and LDS memory. + +```cpp + // Make tensors for global memory. + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_global_layout_padded); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_global_layout_padded); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_global_layout_padded); + + // Allocate LDS memory. + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; + + // Make tensors for lds memory. + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); +``` + +We must specify parameters for copy and convert block indexes to tuple: + +```cpp + // Specify block index as tuple. + const auto block_idxs = ck::make_tuple(static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + // Specify access parameters for copy. + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; +``` + +We create a local tile (per block) and local partitions (per thread) for the global memory (`C`). We also +define and clear an output register (`c_vgpr_reg`) for the accumulation. + +```cpp + auto c_global_local_tile = ck::wrapper::make_local_tile( + c_global_tensor, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Create C vgpr to accumulate results. + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + // Clear C vgpr. + ck::wrapper::clear(c_vgpr_reg); +``` + +We use two specific functions for `blockwise_gemm`: `make_blockwise_gemm_xdl_c_local_partition` and +`make_blockwise_gemm_xdl_c_vgpr`. This helps to choose the appropriate partition for the `C` output +and define tensors with specific layouts for `blockwise_gemm`. In the following step, we use only +generic functions for the CK wrapper. + +Step 3: Create the compute loop. + +```cpp + const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); + ck::index_t i = 0; + do + { + // Get KPerBlock slice. + const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); + auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice); + auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice); + // Create local tiles for A and B. 
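        // Note: in the projection tuple passed to make_local_tile below,
        // ck::wrapper::slice(...) marks a tile_shape axis that the tiled tensor does
        // not carry (N for the A tile, M for the B tile), mirroring slice(KPerBlock)
        // used for the C tile earlier.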
+ auto a_global_local_tile = ck::wrapper::make_local_tile( + a_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); + auto b_global_local_tile = ck::wrapper::make_local_tile( + b_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); + // Copy from global to LDS. + ck::wrapper::blockwise_copy( + a_global_local_tile, a_lds_tensor, thread_layout); + ck::wrapper::blockwise_copy( + b_global_local_tile, b_lds_tensor, thread_layout); + // Synchronize lds. + ck::block_sync_lds(); + // Execute blockwise GEMM. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ++i; + } while(i < num_loop); +``` + +Loop iterate over `K / KPerBlock`. Each time a local tile is created for A and B tensors (tensor per block), +data is copied from global memory to LDS. The `blockwise_gemm` function performs the GEMM +operation on `a_lds_tensor` and `b_lds_tensor`, and stores results in `c_vgpr_reg`. + +The end result from `c_vgpr_reg` is stored in the `C` local partition (tensor per thread): + +```cpp + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +``` + +If you want to dive deep into the details, you can find the entire example +[here](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp). diff --git a/client_example/25_wrapper/tensor_transform_using_wrapper.cpp b/client_example/25_wrapper/tensor_transform_using_wrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b25d85e2d34894f8acade03afc07cd7fa85892f --- /dev/null +++ b/client_example/25_wrapper/tensor_transform_using_wrapper.cpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" + +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence.hpp" + +#include "ck/wrapper/layout.hpp" + +using DataType = int; + +template +void Print1d(const Layout& layout) +{ + std::cout << "Print1d" << std::endl; + for(ck::index_t w = 0; w < ck::wrapper::size(layout); w++) + { + std::cout << layout(ck::make_tuple(w)) << " "; + } + std::cout << std::endl; +} + +template +void Print2d(const Layout& layout) +{ + std::cout << "Print2d" << std::endl; + for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) + { + std::cout << layout(ck::make_tuple(h, w)) << " "; + } + std::cout << std::endl; + } +} + +// Print in (x,y),z pattern +template +void Print3dCustom(const Layout& layout) +{ + std::cout << "Print3dCustom" << std::endl; + for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++) + { + for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) + { + std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + } +} + +int main() +{ + // Layout traverse in row-major + std::cout << "Note: Layout traverse in column-major" << std::endl; + // Basic descriptor 0, 1, 2, ... 
30, 31 (compile-time descriptor) + // (dims:4,8 strides:1,4) + const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}); + const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8); + std::cout << "dims:4,8 strides:1,4" << std::endl; + Print2d(layout_4x8_s1x4); + using Cord1x1Type = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t offset_1x1 = layout_4x8_s1x4.template operator()(); + std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl; + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor) + // dims:4,(2,4) strides:2,(1,8) + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + + std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; + Print2d(layout_4x2x4_s2x1x8); + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) + // dims:(2,2),(2,4) strides:((1,4),(2,8) + const auto shape_2x2x2x4 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), + ck::make_tuple(ck::Number<2>{}, ck::Number<4>{})); + const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), + ck::make_tuple(ck::Number<2>{}, ck::Number<8>{})); + static const auto layout_2x2x2x4_s1x4x2x8 = + ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); + + std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; + Print2d(layout_2x2x2x4_s1x4x2x8); + Print3dCustom(layout_2x2x2x4_s1x4x2x8); + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) + // dims:((2,2),2),4 strides:((1,4),2),8 + // Transform to 2d + const auto shape_2x2x2x4_nested = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<2>{}), + ck::Number<4>{}); + const auto strides_s1x4x2x8_nested = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}), + ck::Number<8>{}); + static const auto layout_2x2x2x4_s1x4x2x8_nested = + ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); + + std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; + Print1d(layout_2x2x2x4_s1x4x2x8_nested); + Print2d(layout_2x2x2x4_s1x4x2x8_nested); + Print3dCustom(layout_2x2x2x4_s1x4x2x8_nested); + + return 0; +} diff --git a/client_example/25_wrapper/wrapper_basic_gemm.cpp b/client_example/25_wrapper/wrapper_basic_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..23245dd1889e3890e3c474075d06f661c915dee6 --- /dev/null +++ b/client_example/25_wrapper/wrapper_basic_gemm.cpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
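Before moving on to the GEMM kernel, the offset arithmetic behind the `dims:4,(2,4) strides:2,(1,8)` layout demonstrated above is worth spelling out: a coordinate `(h, (a, b))` lands at `h*2 + a*1 + b*8`. A standalone check of that formula, with no CK headers involved:

```cpp
#include <cstdio>

int main()
{
    // Strides of the "dims:4,(2,4) strides:2,(1,8)" example layout.
    const int stride_h = 2, stride_a = 1, stride_b = 8;
    for(int b = 0; b < 4; ++b)
        for(int a = 0; a < 2; ++a)
            for(int h = 0; h < 4; ++h)
                std::printf("(%d,(%d,%d)) -> %d\n",
                            h, a, b, h * stride_h + a * stride_a + b * stride_b);
    return 0;
}
```

Every offset from 0 to 31 appears exactly once, matching the `0, 1, 8, 9, 16, 17, ... 30, 31` set listed in the comments above (only the traversal order differs).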
+ +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +// __gfx9__ defined in the above header via ck.hpp +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" +#include "ck/host_utility/device_prop.hpp" + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + + // Specify layouts for global memory. + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + // Specify layouts for tiles. + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto c_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); + // Apply padding for global memory. + auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout)); + auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout)); + auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout)); + // Make tensors for global memory. + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_global_layout_padded); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_global_layout_padded); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_global_layout_padded); + // Allocate lds memory. + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; + // Make tensors for lds memory. + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + // Specify block index as tuple. + const auto block_idxs = ck::make_tuple(static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + // Specify access parameters for copy. 
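As a sanity check on the `__shared__` allocations above: with the tile shape used by this example's `main()` (256 x 128 x 32) and `ck::half_t` elements, the two LDS tiles take 24 KiB per workgroup, comfortably below the 64 KiB of LDS available on gfx9. A quick back-of-the-envelope version:

```cpp
#include <cstdio>

int main()
{
    // Tile extents from main() further down and the fp16 element size.
    constexpr int MPerBlock = 256, NPerBlock = 128, KPerBlock = 32;
    constexpr int bytes_per_elem = 2; // sizeof(ck::half_t)
    const int lds_a = MPerBlock * KPerBlock * bytes_per_elem; // 16384 B
    const int lds_b = NPerBlock * KPerBlock * bytes_per_elem; //  8192 B
    std::printf("LDS per workgroup: %d bytes\n", lds_a + lds_b); // 24576 B
    return 0;
}
```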
+ using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + // Create tile and partition for C. Use specific function for blockwise_gemm to assign the + // appropriate partitions. + auto c_global_local_tile = ck::wrapper::make_local_tile( + c_global_tensor, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Create C vgpr to accumulate results. + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + // Clear C vgpr. + ck::wrapper::clear(c_vgpr_reg); + + // Iterate over K with KPerBlock step. + const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); + ck::index_t i = 0; + do + { + // Get KPerBlock slice. + const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); + auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice); + auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice); + // Create local tiles for A and B. + auto a_global_local_tile = ck::wrapper::make_local_tile( + a_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); + auto b_global_local_tile = ck::wrapper::make_local_tile( + b_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); + // Copy from global to lds. + ck::wrapper::blockwise_copy( + a_global_local_tile, a_lds_tensor, thread_layout); + ck::wrapper::blockwise_copy( + b_global_local_tile, b_lds_tensor, thread_layout); + // Synchronize lds. + ck::block_sync_lds(); + // Execute blockwise gemm. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ++i; + } while(i < num_loop); + // Copy vgpr results to C global memory. 
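A quick note on the trip count of the loop that just closed: `num_loop` is the ceiling of `K / KPerBlock`, and the same ceiling division sizes the launch grid in `PerformGemm`. With this example's problem size that gives 128 K-iterations and a 15 x 32 grid of blocks. A plain-C++ check is shown below; the helper is an illustrative equivalent, not CK's implementation.

```cpp
#include <cstdio>

// Illustrative equivalent of ck::math::integer_divide_ceil.
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    std::printf("K loop iterations: %d\n", div_ceil(4096, 32));                       // 128
    std::printf("grid: %d x %d blocks\n", div_ceil(3840, 256), div_ceil(4096, 128));  // 15 x 32
    return 0;
}
```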
+ ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + SimpleDeviceMem a_mem(M * K * sizeof(DataType)); + SimpleDeviceMem b_mem(K * N * sizeof(DataType)); + SimpleDeviceMem c_mem(M * N * sizeof(DataType)); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_xdl_supported(); + if(!is_supported) + { + std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}), + ck::make_tuple(ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 3840, 4096, 4096, tile_shape, thread_layout); + return 0; +} +#endif diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ceccc5eb8fe965f025f0569d4fc6bb6d31af3922 --- /dev/null +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
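The TFLOPS/GB/s bookkeeping in `PerformGemm` above divides by the time in milliseconds, so the 1e9 and 1e6 factors already fold in the ms-to-s conversion. The sketch below reproduces that arithmetic for the example's 3840 x 4096 x 4096 fp16 problem with a made-up 1.0 ms runtime, purely to show the unit handling.

```cpp
#include <cstdio>

int main()
{
    const double M = 3840, N = 4096, K = 4096;
    const double elem_bytes  = 2.0; // sizeof(ck::half_t)
    const double avg_time_ms = 1.0; // hypothetical measurement
    const double flop  = 2.0 * M * N * K;
    const double bytes = elem_bytes * (M * K + K * N + M * N);
    // flop  / 1e9 / ms == flop  / (time_in_s * 1e12) -> TFLOPS
    // bytes / 1e6 / ms == bytes / (time_in_s * 1e9)  -> GB/s
    std::printf("%.2f TFlops, %.2f GB/s\n",
                flop / 1.0e9 / avg_time_ms, bytes / 1.0e6 / avg_time_ms);
    return 0;
}
```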
+ +#include +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +static constexpr ck::index_t NumDimSpatial = 3; +using DataType = float; +using InputLayout = ck::tensor_layout::convolution::NDHWGC; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ +DeviceImageToColumnPad0(InputTensor input_tensor, + OutputTensor output_tensor, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + // grid layout (dim1, dim0) + const auto block_idxs = + ck::make_tuple(static_cast(blockIdx.y), static_cast(blockIdx.x)); + + // Get local tiles for global memory + auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs); + auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs); + + // Get partition per thread + const auto input_local_partition = + ck::wrapper::make_local_partition(input_local_tile, thread_layout, threadIdx.x); + auto output_local_partition = + ck::wrapper::make_local_partition(output_local_tile, thread_layout, threadIdx.x); + + // Perform copy + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + constexpr ck::index_t scalar_per_vector = 4; + ck::wrapper::copy(input_local_partition, + output_local_partition); +} + +void PerformImageToColumnPad0(const ck::index_t G, + const ck::index_t N, + const ck::index_t Di, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t Do, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t C, + const ck::index_t Z, + const ck::index_t Y, + const ck::index_t X, + std::array filter_strides, + std::array filter_dilations) +{ + const ck::index_t ZYXC = Z * Y * X * C; + const ck::index_t GC = G * C; + + // shape: (G, (Wo, Ho, Do, N)), (C, X, Y, Z)) + const auto shape = ck::make_tuple(ck::make_tuple(G, ck::make_tuple(Wo, Ho, Do, N)), + ck::make_tuple(C, X, Y, Z)); + const auto in_strides = + ck::make_tuple(ck::make_tuple(C, + ck::make_tuple(filter_strides[2] * GC, + filter_strides[1] * Wi * GC, + filter_strides[0] * Hi * Wi * GC, + Di * Hi * Wi * GC)), + ck::make_tuple(1, + filter_dilations[2] * GC, + filter_dilations[1] * Wi * GC, + filter_dilations[0] * Hi * Wi * GC)); + const auto in_layout = ck::wrapper::make_layout(shape, in_strides); + + const auto out_strides = ck::make_tuple( + ck::make_tuple( + ZYXC, + ck::make_tuple(ZYXC * G, Wo * ZYXC * G, Ho * Wo * ZYXC * G, Do * Ho * Wo * ZYXC * G)), + ck::make_tuple(1, C, X * C, Y * X * C)); + const auto out_layout = ck::wrapper::make_layout(shape, out_strides); + + const ck::index_t input_size = N * Di * Hi * Wi * GC; + // Global memory buffers + SimpleDeviceMem in_buf(input_size * sizeof(DataType)); + SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType)); + + // User can choose appropriate number of threads and sizes per block + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}), + 
ck::make_tuple(ck::Number<16>{}, ck::Number<1>{})); + // This example doesn't support padding, user should select tile sizes + // which are divisible by the shape. + const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{}); + + // Create buffers for global memory + auto input_tensor_global = ck::wrapper::make_tensor( + static_cast(in_buf.GetDeviceBuffer()), in_layout); + auto output_tensor_global = ck::wrapper::make_tensor( + static_cast(out_buf.GetDeviceBuffer()), out_layout); + + // grid layout (dim1, dim0) + const ck::index_t grid_size_x = ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout), + ck::wrapper::size<1>(tile_shape)); + const ck::index_t grid_size_y = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout), + ck::wrapper::size<0>(tile_shape)); + + const auto kernel = DeviceImageToColumnPad0; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + input_tensor_global, + output_tensor_global, + tile_shape, + thread_layout); + + std::size_t num_btype = G * N * Do * Ho * Wo * ZYXC * 2 * sizeof(DataType); + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << std::endl; +} + +int main(int argc, char* argv[]) +{ + constexpr ck::index_t G = 4; // number of groups + constexpr ck::index_t N = 32; // batch + constexpr ck::index_t C = 64; // input channel (per group) + constexpr ck::index_t Z = 3; // filter D + constexpr ck::index_t Y = 3; // filter H + constexpr ck::index_t X = 3; // filter W + constexpr ck::index_t Di = 9; // input D + constexpr ck::index_t Hi = 9; // input H + constexpr ck::index_t Wi = 7; // input W + constexpr ck::index_t Do = 7; // output D + constexpr ck::index_t Ho = 7; // output H + constexpr ck::index_t Wo = 5; // output W + PerformImageToColumnPad0(G, + N, + Di, + Hi, + Wi, + Do, + Ho, + Wo, + C, + Z, + Y, + X, + {1, 1, 1} /*filter_strides*/, + {1, 1, 1} /*filter_dilations*/); + return 0; +} diff --git a/client_example/25_wrapper/wrapper_optimized_gemm.cpp b/client_example/25_wrapper/wrapper_optimized_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..31e20342df31c7f228ef0f193da7f32f0f033550 --- /dev/null +++ b/client_example/25_wrapper/wrapper_optimized_gemm.cpp @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
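For scale, the image-to-column example above turns every `(g, n, do, ho, wo)` output position into one row of `Z*Y*X*C` values, so with the dimensions hard-coded in its `main()` the result is a 31360 x 1728 float matrix, roughly 217 MB. A quick computation of those extents:

```cpp
#include <cstdio>

int main()
{
    const long G = 4, N = 32, C = 64, Z = 3, Y = 3, X = 3, Do = 7, Ho = 7, Wo = 5;
    const long rows = G * N * Do * Ho * Wo; // 31360
    const long cols = Z * Y * X * C;        // 1728
    std::printf("column matrix: %ld x %ld floats (%.1f MB)\n",
                rows, cols, rows * cols * 4 / 1.0e6); // 4 bytes per float
    return 0;
}
```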
+ +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +// __gfx9__ defined in the above header via ck.hpp +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" +#include "ck/host_utility/device_prop.hpp" + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims) +{ + if constexpr(DoPad) + { + return ck::wrapper::pad(layout, padding_dims); + } + else + { + return layout; + } +} + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + constexpr auto K1 = GemmTraits::K1; + constexpr auto K0PerBlock = KPerBlock / K1; + const auto K0 = ck::math::integer_divide_ceil(K, K1); + + const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1); + // Create layouts for global memory + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + // Apply padding + auto a_padded_global_layout = + ApplyPadding(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock)); + auto b_padded_global_layout = + ApplyPadding(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock)); + auto c_padded_global_layout = + ApplyPadding(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock)); + // Reshape from M,K to K0,M,K1 + const auto reshaped_dims_idxs = + ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{})); + auto a_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + auto b_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + // Create tensors for global memory + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_padded_unmerged_global_layout); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_padded_unmerged_global_layout); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_padded_global_layout); + // Create layouts and tensors for lds memory. 
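Before the LDS layouts, a note on the K0/K1 split performed just above: `unmerge<1>` reshapes the K axis of the M x K (and N x K) views into `(K0, K1)` with `K = K0 * K1`, so an element `(m, k)` of the original view corresponds to `(k0, m, k1)` with `k = k0 * K1 + k1`. `K1` comes from `GemmTraits`; the value 8 used below is only a plausible stand-in for illustration.

```cpp
#include <cstdio>

int main()
{
    const int K1 = 8, K = 4096, KPerBlock = 32; // K1 is an assumed illustrative value
    std::printf("K0 = %d, K0PerBlock = %d\n", K / K1, KPerBlock / K1); // 512, 4
    const int m = 3, k = 37;
    std::printf("(m=%d, k=%d) -> (k0=%d, m=%d, k1=%d)\n", m, k, k / K1, m, k % K1);
    return 0;
}
```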
+ constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, MPerBlock, K1), + ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, NPerBlock, K1), + ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + K0PerBlock]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + K0PerBlock]; + + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + + const auto block_idxs = ck::make_tuple(ck::wrapper::slice(), + static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + using DimAccessOrder = ck::Tuple, ck::Number<0>, ck::Number<2>>; + constexpr ck::index_t vector_dim = 2; + + // Create tile and partition for C global memory. Use specific gemm + // functions to get appropriate layouts. + auto c_global_local_tile = + ck::wrapper::make_local_tile(c_global_tensor, + tile_shape_k0_m_n_k1, + block_idxs, + make_tuple(ck::wrapper::slice(K0PerBlock), + ck::Number<1>{}, + ck::Number<1>{}, + ck::wrapper::slice(K1))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Define and clear c vgpr register + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + ck::wrapper::clear(c_vgpr_reg); + // Local partitions for lds memory + auto a_lds_tensor_local_partition = + ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x); + auto b_lds_tensor_local_partition = + ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x); + // Lamda to slice tensor, then create local tile and partition + auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) { + const auto k_slice = + ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock), + ck::wrapper::slice(), + ck::wrapper::slice()); + auto local_tile = ck::wrapper::make_local_tile( + tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection); + return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x); + }; + + auto a_global_local_partition = make_global_partition( + a_global_tensor, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + 0); + auto b_global_local_partition = make_global_partition( + b_global_tensor, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + 0); + + // (row-major vgpr layout) + auto a_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(a_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + auto b_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(b_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + // Copy first values to lds + ck::wrapper::copy(a_global_local_partition, + a_vgpr_tensor); + ck::wrapper::copy(b_global_local_partition, + b_vgpr_tensor); + ck::wrapper::copy(a_vgpr_tensor, + a_lds_tensor_local_partition); + 
ck::wrapper::copy(b_vgpr_tensor, + b_lds_tensor_local_partition); + // Pipeline loop + const ck::index_t num_loop = + __builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock)); + // Skip if only tile should be processed + if(num_loop > 1) + { + ck::index_t i = 0; + do + { + auto a_global_local_partition_i = make_global_partition( + a_global_tensor, + make_tuple( + ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + i + 1); + auto b_global_local_partition_i = make_global_partition( + b_global_tensor, + make_tuple( + ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + i + 1); + // Copy data to A vgpr. + ck::wrapper::copy( + a_global_local_partition_i, a_vgpr_tensor); + // Synchronize. + ck::block_sync_lds(); + // Copy data to B vgpr. + ck::wrapper::copy( + b_global_local_partition_i, b_vgpr_tensor); + // Perform gemm. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Synchronize + ck::block_sync_lds(); + // Copy data to A and B lds tiles. + ck::wrapper::copy( + a_vgpr_tensor, a_lds_tensor_local_partition); + ck::wrapper::copy( + b_vgpr_tensor, b_lds_tensor_local_partition); + + ++i; + } while(i < (num_loop - 1)); + } + // Handle tail. + ck::block_sync_lds(); + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Store data from C vgpr to C global memory. + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + SimpleDeviceMem a_mem(M * K * sizeof(DataType)); + SimpleDeviceMem b_mem(K * N * sizeof(DataType)); + SimpleDeviceMem c_mem(M * N * sizeof(DataType)); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_xdl_supported(); + if(!is_supported) + { + std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 3840, 4096, 4096, tile_shape, thread_layout); + return 0; +} +#endif diff --git a/client_example/30_gemm_bf16Aint8B/CMakeLists.txt b/client_example/30_gemm_bf16Aint8B/CMakeLists.txt new file mode 100644 index 
0000000000000000000000000000000000000000..5cfcb68e10cbfab5d1450729e1b55007ead53437 --- /dev/null +++ b/client_example/30_gemm_bf16Aint8B/CMakeLists.txt @@ -0,0 +1,16 @@ +if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES)) + add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp) + target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_gemm_bias_bf16_i8_bf16 gemm_bias_xdl_bf16_i8.cpp) + target_link_libraries(client_gemm_bias_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_gemm_gelu_bf16_i8_bf16 gemm_xdl_gelu_bf16_i8.cpp) + target_link_libraries(client_gemm_gelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_gemm_bf16_i8_bf16 gemm_xdl_bf16_i8.cpp) + target_link_libraries(client_gemm_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_gemm_multiply_bf16_i8_bf16 gemm_xdl_multiply_bf16_i8.cpp) + target_link_libraries(client_gemm_multiply_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/30_gemm_bf16Aint8B/gemm_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/30_gemm_bf16Aint8B/gemm_bias_fastgelu_xdl_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c47e42931ed0dc9a7c242fbe992b4061252e8e09 --- /dev/null +++ b/client_example/30_gemm_bf16Aint8B/gemm_bias_fastgelu_xdl_bf16_i8.cpp @@ -0,0 +1,262 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +// clang-format on +int main(int 
argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 64; + ck::index_t N = 1024; + ck::index_t K = 512; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 8) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD = std::stoi(argv[6]); + StrideE = std::stoi(argv[7]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a0_device_buf(sizeof(A0DataType) * + f_matrix_space_size(M, K, StrideA, A0Layout{})); + SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * + f_matrix_space_size(K, N, StrideB, B0Layout{})); + SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{})); + SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD, ELayout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{d0_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not 
support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{d0_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/30_gemm_bf16Aint8B/gemm_bias_xdl_bf16_i8.cpp b/client_example/30_gemm_bf16Aint8B/gemm_bias_xdl_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a1d449ef8ca59c6af612a5a3a4ff2ff8748000f1 --- /dev/null +++ b/client_example/30_gemm_bf16Aint8B/gemm_bias_xdl_bf16_i8.cpp @@ -0,0 +1,262 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +// clang-format on +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 64; + ck::index_t N = 1024; + ck::index_t K = 512; + + ck::index_t StrideA = M; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + if(argc == 
1) + { + // use default case + } + else if(argc == 8) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD = std::stoi(argv[6]); + StrideE = std::stoi(argv[7]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a0_device_buf(sizeof(A0DataType) * + f_matrix_space_size(M, K, StrideA, A0Layout{})); + SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * + f_matrix_space_size(K, N, StrideB, B0Layout{})); + SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{})); + SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD, ELayout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{d0_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best instance + if(found) 
+ { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{d0_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/30_gemm_bf16Aint8B/gemm_xdl_bf16_i8.cpp b/client_example/30_gemm_bf16Aint8B/gemm_xdl_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f1b7eddb6070b28b7c09e98ad038895e8c1e7c0 --- /dev/null +++ b/client_example/30_gemm_bf16Aint8B/gemm_xdl_bf16_i8.cpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +// clang-format on +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideE = std::stoi(argv[6]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, 
StrideB, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a0_device_buf(sizeof(A0DataType) * + f_matrix_space_size(M, K, StrideA, A0Layout{})); + SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * + f_matrix_space_size(K, N, StrideB, B0Layout{})); + SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 0; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{}, + StrideE, + 
a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/30_gemm_bf16Aint8B/gemm_xdl_gelu_bf16_i8.cpp b/client_example/30_gemm_bf16Aint8B/gemm_xdl_gelu_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc4c34ae7f3e476001775d9bf4bf38e5423991b9 --- /dev/null +++ b/client_example/30_gemm_bf16Aint8B/gemm_xdl_gelu_bf16_i8.cpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = FastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +// clang-format on +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 64; + ck::index_t N = 1024; + ck::index_t K = 512; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideE = std::stoi(argv[6]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a0_device_buf(sizeof(A0DataType) * + f_matrix_space_size(M, K, 
StrideA, A0Layout{})); + SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * + f_matrix_space_size(K, N, StrideB, B0Layout{})); + SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 0; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/30_gemm_bf16Aint8B/gemm_xdl_multiply_bf16_i8.cpp 
b/client_example/30_gemm_bf16Aint8B/gemm_xdl_multiply_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d056a782943eabf5b9bbe491d53dbce734c8170e --- /dev/null +++ b/client_example/30_gemm_bf16Aint8B/gemm_xdl_multiply_bf16_i8.cpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Multiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +// clang-format on +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideE = std::stoi(argv[6]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a0_device_buf(sizeof(A0DataType) * + f_matrix_space_size(M, K, StrideA, A0Layout{})); + SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * + f_matrix_space_size(K, N, StrideB, B0Layout{})); + SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t 
NumBTensor = 1; + constexpr ck::index_t NumDTensor = 1; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer()}, + std::array{b1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/31_grouped_gemm_bf16Aint8B/CMakeLists.txt b/client_example/31_grouped_gemm_bf16Aint8B/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c3483ef5db63e517acdb044cc6cc5d977f2c643a --- /dev/null +++ b/client_example/31_grouped_gemm_bf16Aint8B/CMakeLists.txt @@ -0,0 +1,16 @@ +if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES)) + add_executable(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp) + target_link_libraries(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_grouped_gemm_fastgelu_bf16_i8_bf16 grouped_gemm_fastgelu_xdl_bf16_i8.cpp) + target_link_libraries(client_grouped_gemm_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_grouped_gemm_multiply_bf16_i8_bf16 grouped_gemm_multiply_xdl_bf16_i8.cpp) + target_link_libraries(client_grouped_gemm_multiply_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) + + add_executable(client_grouped_gemm_multiply_bias_fastgelu_bf16_i8_bf16 grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp) + target_link_libraries(client_grouped_gemm_multiply_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) + + 
add_executable(client_grouped_gemm_bf16_i8_bf16 grouped_gemm_xdl_bf16_i8.cpp) + target_link_libraries(client_grouped_gemm_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0bf748cdbb0bcbf33c5c841eaca9c6ab36eec172 --- /dev/null +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + 
for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + } + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back( + std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + gemm_descs.push_back( + {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, + std::array{b0_tensors_device[i]->GetDeviceBuffer(), + b1_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i], 0}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + auto argument_ptr = op_ptr->MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + SimpleDeviceMem gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetElementwiseOps( + argument_ptr.get(), a_element_op, b_element_op, cde_element_op); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0]; + + std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] + + sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] + + sizeof(EDataType) * sum_of_m * problem_size.Ns[0]; + + float tflops = static_cast(flop) 
/ 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return true; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(512); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ns[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_fastgelu_xdl_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f300583d13fc44ba313bbfcdf3c9cc8fa849dfe4 --- /dev/null +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_fastgelu_xdl_bf16_i8.cpp @@ -0,0 +1,284 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" + +#include "ck/host_utility/hip_check_error.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = FastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { 
(void)hipFree(p_mem_); } + + void* p_mem_; +}; + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + } + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 0; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back( + std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + gemm_descs.push_back( + {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {}, 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, + std::array{b0_tensors_device[i]->GetDeviceBuffer(), + b1_tensors_device[i]->GetDeviceBuffer()}, + std::array{}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i], 0}, + std::array{}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + auto argument_ptr = op_ptr->MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + SimpleDeviceMem gemm_kernel_args_dev( + 
op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetElementwiseOps( + argument_ptr.get(), a_element_op, b_element_op, cde_element_op); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0]; + + std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] + + sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] + + sizeof(EDataType) * sum_of_m * problem_size.Ns[0]; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return true; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(512); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ns[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b284c74d4a75cf3634b77e703d369df44ba8098 --- /dev/null +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp" + +#include "ck/host_utility/hip_check_error.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using B0DataType = I8; +using B1DataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using B0Layout = Row; +using B1Layout = B0Layout; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyAddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + } + + constexpr ck::index_t NumDTensor = 2; + + using GroupedGemmKernelArgument = + ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back( + std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + + 
d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + gemm_descs.push_back({problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + {0, 0}}); + + grouped_gemm_kernel_args_.push_back( + {a0_tensors_device[i]->GetDeviceBuffer(), + b0_tensors_device[i]->GetDeviceBuffer(), + {b1_tensors_device[i]->GetDeviceBuffer(), d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {0, 0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + SimpleDeviceMem gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer()); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, 0, 20, 50}); + + std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0]; + + std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] + + sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] + + sizeof(EDataType) * sum_of_m * problem_size.Ns[0]; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return true; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(1 + rand() % 1024); + 
problem_size.Ns.push_back(6144); + problem_size.Ks.push_back(4096); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ns[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + std::cout << " M = " << problem_size.Ms[i] << " N = " << problem_size.Ns[i] << " K " + << problem_size.Ks[i] << std::endl; + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6cc83e06f68555f84d642219e3f49aafd627fa66 --- /dev/null +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp" + +#include "ck/host_utility/hip_check_error.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using B0DataType = I8; +using B1DataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using B0Layout = Row; +using B1Layout = B0Layout; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Multiply = ck::tensor_operation::element_wise::Multiply; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Multiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype 
= 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + } + + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = + ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back(std::make_unique( + sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back( + std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + + gemm_descs.push_back({problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + {0}}); + + grouped_gemm_kernel_args_.push_back({a0_tensors_device[i]->GetDeviceBuffer(), + b0_tensors_device[i]->GetDeviceBuffer(), + {b1_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + SimpleDeviceMem gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer()); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, 0, 20, 50}); + + std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0]; + + std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] + + sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] + + sizeof(EDataType) * sum_of_m * problem_size.Ns[0]; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops 
> best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return true; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(1 + rand() % 1024); + problem_size.Ns.push_back(4096); + problem_size.Ks.push_back(4096); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ns[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + std::cout << " M = " << problem_size.Ms[i] << " N = " << problem_size.Ns[i] << " K " + << problem_size.Ks[i] << std::endl; + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_xdl_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..557dea767688f990a0713807107607e9a895fdee --- /dev/null +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_xdl_bf16_i8.cpp @@ -0,0 +1,287 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp" + +#include "ck/host_utility/hip_check_error.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + 
std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + } + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 0; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back( + std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + gemm_descs.push_back( + {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {}, 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, + std::array{b0_tensors_device[i]->GetDeviceBuffer(), + b1_tensors_device[i]->GetDeviceBuffer()}, + std::array{}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i], 0}, + std::array{}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + auto argument_ptr = op_ptr->MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + SimpleDeviceMem gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + 
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetElementwiseOps( + argument_ptr.get(), a_element_op, b_element_op, cde_element_op); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, 0, 20, 50}); + + std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0]; + + std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] + + sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] + + sizeof(EDataType) * sum_of_m * problem_size.Ns[0]; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return true; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(1 + rand() % 1024); + problem_size.Ns.push_back(4096); + problem_size.Ks.push_back(4096); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ns[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + std::cout << " M = " << problem_size.Ms[i] << " N = " << problem_size.Ns[i] << " K " + << problem_size.Ks[i] << std::endl; + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index eb793b3cbd6e693fd8cf94ef3489c6ac3eb915c7..acb57d7bb045e8d7d917fc42ba5ea1e04e4cb252 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -6,56 +6,74 @@ if (DTYPES) add_definitions(-DDTYPES) if (DTYPES MATCHES "int8") add_definitions(-DCK_ENABLE_INT8) - if(NOT DEFINED ${CK_ENABLE_INT8}) - set(CK_ENABLE_INT8 "ON") - endif() + set(CK_ENABLE_INT8 "ON") endif() if (DTYPES MATCHES "fp8") add_definitions(-DCK_ENABLE_FP8) - if(NOT DEFINED ${CK_ENABLE_FP8}) - set(CK_ENABLE_FP8 "ON") - endif() + set(CK_ENABLE_FP8 "ON") + endif() + if (DTYPES MATCHES "bf8") + add_definitions(-DCK_ENABLE_BF8) + set(CK_ENABLE_BF8 "ON") endif() if (DTYPES MATCHES "fp16") add_definitions(-DCK_ENABLE_FP16) - if(NOT DEFINED ${CK_ENABLE_FP16}) - set(CK_ENABLE_FP16 "ON") - endif() + set(CK_ENABLE_FP16 "ON") endif() if (DTYPES MATCHES "fp32") add_definitions(-DCK_ENABLE_FP32) - if(NOT DEFINED ${CK_ENABLE_FP32}) - set(CK_ENABLE_FP32 "ON") - endif() + set(CK_ENABLE_FP32 "ON") endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) - if(NOT DEFINED ${CK_ENABLE_FP64}) - set(CK_ENABLE_FP64 "ON") - endif() + set(CK_ENABLE_FP64 "ON") endif() if (DTYPES MATCHES "bf16") add_definitions(-DCK_ENABLE_BF16) - if(NOT DEFINED ${CK_ENABLE_BF16}) - set(CK_ENABLE_BF16 "ON") - endif() + set(CK_ENABLE_BF16 "ON") endif() message("DTYPES macro set to ${DTYPES}") else() - add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 
-DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) - if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES}) - set(CK_ENABLE_ALL_DTYPES "ON") + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + set(CK_ENABLE_INT8 "ON") + set(CK_ENABLE_FP16 "ON") + set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_FP64 "ON") + set(CK_ENABLE_BF16 "ON") + if (GPU_TARGETS MATCHES "gfx94") + add_definitions(-DCK_ENABLE_FP8 -DCK_ENABLE_BF8) + set(CK_ENABLE_FP8 "ON") + set(CK_ENABLE_BF8 "ON") endif() endif() -find_package(composable_kernel COMPONENTS device_operations) +if (GPU_TARGETS) + if (GPU_TARGETS MATCHES "gfx9") + add_definitions(-DCK_USE_XDL) + set(CK_USE_XDL "ON") + endif() + if (GPU_TARGETS MATCHES "gfx11") + add_definitions(-DCK_USE_WMMA) + set(CK_USE_WMMA "ON") + endif() +else() + add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) + set(CK_USE_XDL "ON") + set(CK_USE_WMMA "ON") +endif() + +find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations) +if(GPU_TARGETS MATCHES "gfx9") + find_package(composable_kernel COMPONENTS device_contraction_operations) +endif() find_package(hip REQUIRED PATHS /opt/rocm) message(STATUS "Build with HIP ${hip_VERSION}") # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) - IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build")) + IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build") + AND (NOT "${subdir}" MATCHES ".vscode")) add_subdirectory(${subdir}) ENDIF() ENDFOREACH() diff --git a/cmake/ClangTidy.cmake b/cmake/ClangTidy.cmake index 01b348c4583f5d18b202410a694ee6a8e460580a..cf77991a64a7543ac7b97d604ea3b9e72a626c62 100644 --- a/cmake/ClangTidy.cmake +++ b/cmake/ClangTidy.cmake @@ -149,7 +149,7 @@ function(clang_tidy_check TARGET) add_custom_target(${tidy_target} # for some targets clang-tidy not able to get information from .clang-tidy DEPENDS ${SOURCE} - COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" + COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_PLATFORM_AMD__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..." ) diff --git a/cmake/DoxygenDoc.cmake b/cmake/DoxygenDoc.cmake index 2e3669fcdf48fde87c23a84c2493ce26aecda7af..c91308b5bb6db3cac55628c2bf5123f261cc838d 100644 --- a/cmake/DoxygenDoc.cmake +++ b/cmake/DoxygenDoc.cmake @@ -309,6 +309,8 @@ XML_OUTPUT XML_PROGRAMLISTING ) +set(WARN_AS_ERROR YES) + set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") function(add_doxygen_doc) diff --git a/cmake/Embed.cmake b/cmake/Embed.cmake index f3ef595c995d0336359a4e9b12583ecf8b6ed324..8023ef014a4514af5b613fc36be002673e83806d 100644 --- a/cmake/Embed.cmake +++ b/cmake/Embed.cmake @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 66139cc710992df07151b85537c28f483e2ef452..93fd306e98af3bf86fcc6f0f213d029f6f3c4a26 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -2,7 +2,7 @@ # # MIT License # -# Copyright (c) 2017 Advanced Micro Devices, Inc. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -66,10 +66,11 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt + -Wno-unused-template ) if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") list(APPEND CMAKE_COMPILER_WARNINGS @@ -94,6 +95,8 @@ else() -Wno-weak-vtables -Wno-covered-switch-default -Wno-unsafe-buffer-usage + -Wno-unused-lambda-capture + -Wno-nvcc-compat ) else() if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") diff --git a/cmake/getopt.cmake b/cmake/getopt.cmake new file mode 100644 index 0000000000000000000000000000000000000000..dd985ff472287797a69550d785459251e01d62e5 --- /dev/null +++ b/cmake/getopt.cmake @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +add_library(getopt::getopt INTERFACE IMPORTED GLOBAL) + +if(WIN32) + include(FetchContent) + + FetchContent_Declare( + getopt + GIT_REPOSITORY https://github.com/apwojcik/getopt.git + GIT_TAG main + SYSTEM + ) + + set(__build_shared_libs ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") + + FetchContent_MakeAvailable(getopt) + + # Restore the old value of BUILD_SHARED_LIBS + set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE) + + FetchContent_GetProperties(getopt) + + target_link_libraries(getopt::getopt INTERFACE wingetopt) + target_include_directories(getopt::getopt INTERFACE ${getopt_SOURCE_DIR}/src) +endif() \ No newline at end of file diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake deleted file mode 100644 index d6577ac33e71c575b2aa441da30838f58157f0b4..0000000000000000000000000000000000000000 --- a/cmake/googletest.cmake +++ /dev/null @@ -1,50 +0,0 @@ -include(FetchContent) - -set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") - -if(GOOGLETEST_DIR) - set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") -endif() - -message(STATUS "Fetching GoogleTest") - -list(APPEND GTEST_CMAKE_CXX_FLAGS - -Wno-undef - -Wno-reserved-identifier - -Wno-global-constructors - -Wno-missing-noreturn - -Wno-disabled-macro-expansion - -Wno-used-but-marked-unused - -Wno-switch-enum - -Wno-zero-as-null-pointer-constant - -Wno-unused-member-function - -Wno-comma - -Wno-old-style-cast - -Wno-deprecated - -Wno-unsafe-buffer-usage -) -message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}") - -FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG b85864c64758dec007208e56af933fc3f52044ee -) - -# Will be necessary for windows build -# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) -FetchContent_GetProperties(googletest) -if(NOT 
googletest_POPULATED) - FetchContent_Populate(googletest) - add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) -endif() - -target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) -target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) -target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) -target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) - -set_target_properties(gtest PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(gtest_main PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(gmock PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(gmock_main PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake new file mode 100644 index 0000000000000000000000000000000000000000..0915f534112f441fbf77a93dbcdb58f2bb7f141c --- /dev/null +++ b/cmake/gtest.cmake @@ -0,0 +1,70 @@ +include(FetchContent) + +set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") + +if(GOOGLETEST_DIR) + set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") +endif() + +FetchContent_Declare( + GTest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 +) + +# Suppress ROCMChecks WARNING on GoogleTests +set(ROCM_DISABLE_CHECKS FALSE) +macro(rocm_check_toolchain_var var access value list_file) + if(NOT ROCM_DISABLE_CHECKS) + _rocm_check_toolchain_var("${var}" "${access}" "${value}" "${list_file}") + endif() +endmacro() + +if(WIN32) + set(gtest_force_shared_crt ON CACHE_INTERNAL "") +endif() + +set(BUILD_GMOCK OFF CACHE INTERNAL "") +set(INSTALL_GTEST OFF CACHE INTERNAL "") + +# Store the current value of BUILD_SHARED_LIBS +set(__build_shared_libs ${BUILD_SHARED_LIBS}) +set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") + +set(ROCM_DISABLE_CHECKS TRUE) +FetchContent_MakeAvailable(GTest) +set(ROCM_DISABLE_CHECKS FALSE) + +# Restore the old value of BUILD_SHARED_LIBS +set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE) + +set(BUILD_GMOCK OFF CACHE INTERNAL "") +set(INSTALL_GTEST OFF CACHE INTERNAL "") + +set(GTEST_CXX_FLAGS + -Wno-undef + -Wno-reserved-identifier + -Wno-global-constructors + -Wno-missing-noreturn + -Wno-disabled-macro-expansion + -Wno-used-but-marked-unused + -Wno-switch-enum + -Wno-zero-as-null-pointer-constant + -Wno-unused-member-function + -Wno-comma + -Wno-old-style-cast + -Wno-deprecated + -Wno-unsafe-buffer-usage + -Wno-float-equal +) + +if(WIN32) + list(APPEND GTEST_CXX_FLAGS + -Wno-suggest-destructor-override + -Wno-suggest-override + -Wno-nonportable-system-include-path + -Wno-language-extension-token) +endif() + +target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS}) +target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1ca0d12821067edb2ebdd8b763a894168e70e647 --- /dev/null +++ b/codegen/CMakeLists.txt @@ -0,0 +1,58 @@ +cmake_minimum_required(VERSION 3.16) +project(composable_kernel_host) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) 
+ +find_package(ROCM) +include(ROCMInstallTargets) +include(ROCMTest) + +rocm_setup_version(VERSION 1.0) + +list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) +include(Embed) +file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS + ${CK_ROOT}/include/ck/*.hpp) +# printouts fot debug purposes +# message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") +# message(STATUS "RELATIVE: ${CK_ROOT}/include") +add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) + +add_compile_options(-std=c++17) + +file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) +# TODO: Use object library +add_library(ck_host STATIC ${SOURCES}) +target_link_libraries(ck_host PRIVATE ck_headers) + +set_target_properties(ck_host PROPERTIES + LINKER_LANGUAGE CXX + POSITION_INDEPENDENT_CODE ON) + +# target_include_directories(ck_host PUBLIC +# $ +# ) + +add_executable(ck-template-driver driver/main.cpp) +target_link_libraries(ck-template-driver ck_host) + +rocm_install_targets( + TARGETS ck_host ck_headers + EXPORT ck_host_targets + INCLUDE include + PRIVATE +) +rocm_export_targets( + EXPORT ck_host_targets + NAMESPACE composable_kernel:: +) + +if(BUILD_TESTING) + add_subdirectory(test) +endif() + diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c7d295de943e1feb5b139d933118db745e7edee3 --- /dev/null +++ b/codegen/driver/main.cpp @@ -0,0 +1,105 @@ + +#include +#include +#include +#include +#include +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/stringutils.hpp" + +using ck::host::Transform; + +struct Emitters +{ + // retrieve the hard-coded instances provided, template them, and then store them in a map + std::unordered_map()>> m; + + template + void Register(const std::string& name, const std::string& prologue, const std::string& epilogue) + { + m[name] = [&] { + auto configs = T::CreateOperations(prologue, epilogue); + + return Transform(configs, [](const auto& ops) { return ToTuple(ops); }); + }; + } + + // takes in an operation instance and uses it to substitute the correct values into the template + template + static std::string ToTuple(const T& ops) + { + auto templates = Transform( + ops, [](const auto& op) { return " " + op.ToSolution().ToTemplateString(); }); + return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">"; + } + + // Join together all the strings in the map + std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); } + + std::vector List() const + { + return Transform(m, [](auto&& p) { return p.first; }); + } +}; + +int main(int argc, const char* argv[]) +{ + std::string prog = argv[0]; + std::vector args(argv + 1, argv + argc); + + // Specify problem type and problem size + ck::host::device_gemm_multiple_d::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + + // user provided fusion + std::string prologue = ""; + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};)"; + + // Load in operations into the Register + Emitters e; + e.Register( + "DeviceGemmMultipleD_Xdl_CShuffle", prologue, 
epilogue); + + if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) { + return arg == "-h" or arg == "--help"; + })) + { + std::cout << "USAGE:" << std::endl; + std::cout << " " << prog << " [TEMPLATE]" << std::endl; + std::cout << std::endl; + std::cout << "FLAGS:" << std::endl; + std::cout << " -h, --help Show help" << std::endl; + std::cout << std::endl; + std::cout << "TEMPLATES:" << std::endl; + for(auto x : e.List()) + std::cout << " " << x << std::endl; + std::cout << std::endl; + return 0; + } + + // print out all the instances for the operation that was chosen at the command line + for(auto name : args) + std::cout << e.Emit(name) << std::endl; + + return 0; +} diff --git a/codegen/include/ck/host/device_gemm_multiple_d.hpp b/codegen/include/ck/host/device_gemm_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..02c19c88e7335c467da119dd46fb3408b80b319e --- /dev/null +++ b/codegen/include/ck/host/device_gemm_multiple_d.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + bool TransA = false; + bool TransB = false; + bool TransE = false; + std::vector DsTrans = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string CDEElementOp = "ck::Tuple<>"; + + std::string GetIncludeHeader() const; + + std::vector GetSolutions(const std::string& arch) const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..359da7d8cf5ee48aab4cd6a4e987a49830fac88c --- /dev/null +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
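// Minimal usage sketch for the Problem struct above (illustrative only). It
// assumes the three-argument GetSolutions(arch, prologue, epilogue) overload
// declared in problem.hpp and that "gfx90a" is one of the XDL-capable arch
// strings accepted by get_xdlop_archs(); for any other arch the call returns
// an empty vector.
#include <iostream>
#include "ck/host/device_gemm_multiple_d/problem.hpp"

inline void print_gemm_instances()
{
    ck::host::device_gemm_multiple_d::Problem prob;
    prob.M      = 1024;
    prob.N      = 1024;
    prob.K      = 1024;
    prob.TransB = true; // column-major B selects the col-major tuning tables

    // empty prologue/epilogue: keep the default element-wise operators
    for(const auto& solution : prob.GetSolutions("gfx90a", "", ""))
        std::cout << solution.ToTemplateString() << std::endl;
}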
+ +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_gemm_multiple_d/problem.hpp" + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +// defines all values need for an instance of fwd conv +struct Operation_Xdl_CShuffle +{ + // returns a vector of instances, only given fusion operators: will use default problem spec + static std::vector> + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, given a problem spec and fusion operators + static std::vector + CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); + TensorDesc A{}; + TensorDesc B{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::vector Ds = {}; + TensorDesc E{}; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string cde_elem_op = Bilinear; + std::string prologue = ""; + std::string epilogue = ""; + std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + // tuning parameters + operation::TileDesc tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + // functions to update fusion operators if provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + /**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f4036328ecc6848bc559a371d0d05fcbf482374e --- /dev/null +++ b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
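// Minimal sketch of how the operation list above is meant to be consumed
// (illustrative only): the prologue/epilogue-only CreateOperations overload
// returns one vector of Operation_Xdl_CShuffle per A/B layout combination, and
// ToSolution() renders each entry as a templated
// DeviceGemmMultipleD_Xdl_CShuffle instance string.
#include <iostream>
#include "ck/host/device_gemm_multiple_d/operation.hpp"

inline void dump_default_gemm_operations()
{
    using ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle;

    const auto per_layout = Operation_Xdl_CShuffle::CreateOperations("", "");
    for(const auto& ops : per_layout) // one entry per TransA/TransB combination
        for(const auto& op : ops)     // one entry per hard-coded tile configuration
            std::cout << op.ToSolution().ToTemplateString() << "\n";
}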
+ +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +// defines the problem specification for a GEMM operation +struct Problem +{ + // dimensions for GEMM operation + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + // layouts for tensors + bool TransA = false; + bool TransB = false; + bool TransE = false; + std::vector DsTrans = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = PassThrough; + std::string BElementOp = PassThrough; + std::string CDEElementOp = PassThrough; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5ad1dce1762b847f7385400608d5a776317676b3 --- /dev/null +++ b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" + +namespace ck { +namespace host { +namespace conv { + +// defines the values needed for an instance of forward convolution and functions to return +// (templated) instances +struct Operation_Conv_Fwd_Xdl_Cshuffle +{ + // returns a vector of instances given the fusion operations, uses default values for problem + // spec + static std::vector + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, provided with a problem spec and fusion operations + static std::vector CreateOperations( + const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue); + std::size_t NumDim; + TensorDesc A{}; + TensorDesc B{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::vector Ds = {}; + TensorDesc E{}; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string cde_elem_op = PassThrough; + std::string prologue = ""; + std::string epilogue = ""; + std::string conv_specialization = + "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default"; + std::string gemm_specialization = + "ck::tensor_operation::device::GemmSpecialization::MNKPadding"; + // tuning parameters + operation::TileDesc tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + // functions to update fusion operations if they are provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace conv +} // namespace host +} // 
namespace ck diff --git a/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..433f9a8fc92f100486780f0a6d6e246d2e83a43f --- /dev/null +++ b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace conv { + +// defines the problem specification for a forward convolution operation +struct Problem_Conv_Fwd +{ + std::size_t NumDim = 0; + // size of a forward convolution operation + std::size_t G = 0; + std::size_t N = 0; + std::size_t C = 0; + std::size_t Hi = 0; + std::size_t Wi = 0; + std::size_t Ho = 0; + std::size_t Wo = 0; + std::size_t K = 0; + std::size_t Y = 0; + std::size_t X = 0; + Layout ALayout = Layout::NHWGC; + Layout BLayout = Layout::GKYXC; + Layout ELayout = Layout::NHWGK; + std::vector DsLayout = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string CDEElementOp = "ck::tensor_operation::element_wise::PassThrough"; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; +}; + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/headers.hpp b/codegen/include/ck/host/headers.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54f8d9f7317c3511a12c723bd6eadfaba711ffa9 --- /dev/null +++ b/codegen/include/ck/host/headers.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace host { + +std::unordered_map GetHeaders(); + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..84ef92f0a039706e1da4719ca9576667fac44494 --- /dev/null +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
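// Minimal usage sketch for Problem_Conv_Fwd above (illustrative only; the
// sizes are made up). As with the GEMM path, GetSolutions() is gated on
// get_xdlop_archs(), so a non-XDL arch string yields an empty vector.
#include <iostream>
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"

inline void print_conv_fwd_instances()
{
    ck::host::conv::Problem_Conv_Fwd prob;
    prob.NumDim = 2; // 2-D forward convolution
    prob.G  = 1;
    prob.N  = 4;
    prob.C  = 64;
    prob.K  = 64;
    prob.Hi = 28;
    prob.Wi = 28;
    prob.Ho = 26; // 28 - 3 + 1 for a 3x3 filter, stride 1, no padding
    prob.Wo = 26;
    prob.Y  = 3;
    prob.X  = 3;

    for(const auto& solution : prob.GetSolutions("gfx90a", "", ""))
        std::cout << solution.ToTemplateString() << std::endl;
}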
+ +#pragma once + +#include + +namespace ck { +namespace host { +namespace operation { + +struct TileDesc +{ + int block_size = 0; + int m_per_block = 0; + int n_per_block = 0; + int k_per_block = 0; + int ak1 = 0; + int bk1 = 0; + int m_per_XDL = 0; + int n_per_XDL = 0; + int m_Xdl_per_wave = 0; + int n_Xdl_per_wave = 0; + int num_gemmk_prefetch_stage = 0; +}; +struct BlockTransferDesc +{ + std::string thread_cluster_length = ""; + std::string thread_cluster_arrange_order = ""; + std::string src_access_order = ""; + int src_vec_dim = 0; + int src_scalar_per_vector = 0; + int dst_scalar_per_vector_k1 = 0; + int lds_add_extra_dim = 0; +}; +struct CShuffleDesc +{ + int m_Xdl_per_wave_per_shuffle = 0; + int n_Xdl_per_wave_per_shuffle = 0; +}; +struct CBlockTransferDesc +{ + std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = ""; + int scalar_per_vector_n_wave_n_per_Xdl = 0; +}; + +} // namespace operation +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89c1884d2e4283604eecf33cf6b5dffe42e1c367 --- /dev/null +++ b/codegen/include/ck/host/stringutils.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace ck { +namespace host { + +template +std::string trim(const std::string& s, F f) +{ + auto start = std::find_if_not(s.begin(), s.end(), f); + auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base(); + return {start, last}; +} + +inline std::string trim(const std::string& s) +{ + return trim(s, [](unsigned char c) { return std::isspace(c); }); +} + +template +inline std::string JoinStrings(Strings strings, const std::string& delim) +{ + auto it = strings.begin(); + if(it == strings.end()) + return ""; + + auto nit = std::next(it); + return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) { + return std::move(x) + delim + std::move(y); + }); +} + +template +inline std::string +InterpolateString(const std::string& input, F f, std::string start = "${", std::string end = "}") +{ + std::string result = ""; + result.reserve(input.size()); + auto it = input.begin(); + while(it != input.end()) + { + auto next_start = std::search(it, input.end(), start.begin(), start.end()); + auto next_end = std::search(next_start, input.end(), end.begin(), end.end()); + result.append(it, next_start); + if(next_start == input.end()) + break; + if(next_end == input.end()) + { + throw std::runtime_error("Unbalanced brackets"); + } + auto r = f(next_start + start.size(), next_end); + result.append(r.begin(), r.end()); + it = next_end + end.size(); + } + return result; +} +inline std::string InterpolateString(const std::string& input, + const std::unordered_map& vars, + std::string start = "${", + std::string end = "}") +{ + return InterpolateString( + input, + [&](auto start_it, auto last_it) { + auto key = trim({start_it, last_it}); + auto it = vars.find(key); + if(it == vars.end()) + throw std::runtime_error("Unknown key: " + key); + return it->second; + }, + std::move(start), + std::move(end)); +} + +template +inline auto Transform(const Range& r, F f) -> std::vector +{ + std::vector result; + std::transform(r.begin(), r.end(), std::back_inserter(result), f); + return result; +} + +template +inline 
auto Transform(const Range1& r1, const Range2& r2, F f) + -> std::vector +{ + std::vector result; + assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end())); + std::transform(r1.begin(), r1.end(), r2.begin(), std::back_inserter(result), f); + return result; +} + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8bad7bf89c55fccb6e052347f0b5689d20200e4a --- /dev/null +++ b/codegen/include/ck/host/types.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +namespace ck { +namespace host { + +// holds the templated instance, substitues values into template from instancess +struct Solution +{ + + Solution() = default; + Solution(std::string str, std::unordered_map values); + std::string ToTemplateString() const; + std::string GetTemplateParameter(const std::string& name) const; + template + T GetTemplateParameter(const std::string& name) const + { + T result; + std::stringstream ss(GetTemplateParameter(name)); + ss >> result; + return result; + } + + private: + std::string template_str; + std::unordered_map template_values; +}; + +// supported data types +enum class DataType +{ + Half, + Float, + Int8, + Int32 +}; +std::string ToString(DataType dt); + +// supported layouts: gemm and fwd conv +enum class Layout +{ + Row, + Column, + GKYXC, + GKCYX, + GNHWK, + GNHWC, + NHWGC, + NHWGK +}; +std::string ToString(Layout dl); +Layout ToLayout(bool Trans); // returns the layout for gemm + +// supported GEMM types +enum class GemmType +{ + Default +}; +std::string ToString(GemmType gt); + +struct TensorDesc +{ + DataType element; + Layout layout; +}; + +std::string SequenceStr(const std::vector& v); + +std::string MakeTuple(const std::vector& v); + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wglobal-constructors" +template +const std::string S = SequenceStr({xs...}); +#pragma clang diagnostic pop + +constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; +constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear"; + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/utils.hpp b/codegen/include/ck/host/utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21926814f1d27a0727d952a1a5674e7ada520012 --- /dev/null +++ b/codegen/include/ck/host/utils.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace host { + +std::size_t integer_divide_ceil(std::size_t x, std::size_t y); + +const std::unordered_set& get_xdlop_archs(); +} // namespace host +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d.cpp b/codegen/src/device_gemm_multiple_d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..44bc051a8b021e1e3a7461732034dc77e063ae76 --- /dev/null +++ b/codegen/src/device_gemm_multiple_d.cpp @@ -0,0 +1,38 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
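// Minimal sketch of the ${...} substitution that Solution and ToSolution() are
// built on (illustrative only; the template and values here are made up).
// InterpolateString replaces every ${key} in the input with vars.at(key) and
// throws on an unknown key or unbalanced brackets.
#include <iostream>
#include <string>
#include <unordered_map>
#include "ck/host/stringutils.hpp"

inline void interpolate_demo()
{
    const std::string tmpl = "DeviceOp<${ADataType}, ${BlockSize}>";
    const std::unordered_map<std::string, std::string> vars = {
        {"ADataType", "ck::half_t"},
        {"BlockSize", "256"},
    };
    // prints: DeviceOp<ck::half_t, 256>
    std::cout << ck::host::InterpolateString(tmpl, vars) << std::endl;
}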
+ +#include "ck/host/device_gemm_multiple_d/problem.hpp" +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +// return the relevant device op file based on the operation +std::string Problem::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"; +} + +// returns templated instances when provided with a problem specification +std::vector Problem::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations( + *this, prologue, epilogue); // obtains vector of instances + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); // template instance with correct values + }); + return result; +} + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fff75c196263961394445bd27162968a563e112b --- /dev/null +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -0,0 +1,347 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/types.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +// calculate appropriate Gemm Specification based on input tensor dimensions +static std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +// function to update prologue/epilogue with user provided operation +void Operation_Xdl_CShuffle::update_prologue(const std::string& pro) +{ + if(!pro.empty()) + { + this->prologue = pro; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->prologue = ""; + } +} + +void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) +{ + if(!epi.empty()) + { + this->epilogue = epi; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->epilogue = ""; + } +} + +// accounts for all possible combinations of Row/Col major +static Layout ToLayout(bool Trans) { return Trans ? 
Layout::Column : Layout::Row; } + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Xdl_CShuffle::CreateOperations( + const Problem& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK| +// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch| +// | | | | | | | | Wave| Wave| Stage| +// | | | | | | | | | | | + { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1}, + { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1}, + { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, + { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, + // clang-format on + }; + + std::vector a_block_descriptions_rowmajor = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + // clang-format on + }; + + std::vector a_block_descriptions_colmajor = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + // clang-format on + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, + }; + + std::vector b_block_descriptions_rowmajor = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + // clang-format on + }; + + std::vector b_block_descriptions_colmajor = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 4>, 8}, + { S<1, 16, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + // clang-format on + }; + + // choose correct arrangement of tuning parameters based on the layout of each tensor + const auto a_block_descriptions = + prob.TransA ? a_block_descriptions_colmajor : a_block_descriptions_rowmajor; + const auto b_block_descriptions = + prob.TransB ? b_block_descriptions_colmajor : b_block_descriptions_rowmajor; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; + x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { + return TensorDesc{dt, ToLayout(trans)}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + x.tile_desc.m_per_block, + x.tile_desc.n_per_block, + x.tile_desc.k_per_block); + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } + return result; +} + +// set up instances when not provided with a problem specification, use default operation values and +// all possible layout combinations +std::vector> +Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) +{ + std::vector problems; + for(bool TransA : {true, false}) + for(bool TransB : {true, false}) + { + Problem prob; + prob.TransA = TransA; + prob.TransB = TransB; + problems.push_back(prob); + } + return Transform(problems, + [&](const Problem& p) { return 
CreateOperations(p, prologue, epilogue); }); +} + +static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = + "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<${LayoutA}, ${LayoutB}, " + "${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, " + "${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, " + "${CDEElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, " + "${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, " + "${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, " + "${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, " + "${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, " + "${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, " + "${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, " + "${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, " + "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, " + "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " + "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " + "${CDEBlockTransferScalarPerVector_NPerBlock}>"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Xdl_CShuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.m_per_block) + "_" + + std::to_string(this->tile_desc.n_per_block) + "_" + + std::to_string(this->tile_desc.k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB", ToString(this->B.layout)}, + {"LayoutDs", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))}, + {"LayoutE", ToString(this->E.layout)}, + {"ADataType", ToString(this->A.element)}, + {"BDataType", ToString(this->B.element)}, + {"AccDataType", ToString(this->acc)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"DsDataType", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))}, + {"EDataType", ToString(this->E.element)}, + {"AElementwiseOperation", this->a_elem_op}, + {"BElementwiseOperation", this->b_elem_op}, + {"CDEElementwiseOperation", this->cde_elem_op}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"MPerBlock", std::to_string(this->tile_desc.m_per_block)}, + {"NPerBlock", std::to_string(this->tile_desc.n_per_block)}, + {"KPerBlock", std::to_string(this->tile_desc.k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)}, + {"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + 
this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"BBlockTransferThreadClusterLengths_BK0_N_BK1", + this->b_block_transfer.thread_cluster_length}, + {"BBlockTransferThreadClusterArrangeOrder", + this->b_block_transfer.thread_cluster_arrange_order}, + {"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order}, + {"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)}, + {"BBlockTransferSrcScalarPerVector", + std::to_string(this->b_block_transfer.src_scalar_per_vector)}, + {"BBlockTransferDstScalarPerVector_BK1", + std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)}, + {"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CDEBlockTransferScalarPerVector_NPerBlock", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + }; + + return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values), + std::move(values)}; +} + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c689e5ec950e5a8eae22a8328150a895f4c03048 --- /dev/null +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp @@ -0,0 +1,42 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
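// Minimal sketch of how a generated Solution can be inspected after the
// interpolation above (illustrative only): the per-instance values remain
// queryable through GetTemplateParameter() under the same keys used in the
// map, e.g. "BlockSize" or "MPerBlock".
#include <iostream>
#include "ck/host/device_gemm_multiple_d/operation.hpp"

inline void inspect_first_gemm_instance()
{
    using ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle;

    const auto ops      = Operation_Xdl_CShuffle::CreateOperations("", "").front();
    const auto solution = ops.front().ToSolution();

    std::cout << solution.ToTemplateString() << "\n"
              << "block size: " << solution.GetTemplateParameter<int>("BlockSize") << "\n"
              << "tile: " << solution.GetTemplateParameter<int>("MPerBlock") << " x "
              << solution.GetTemplateParameter<int>("NPerBlock") << "\n";
}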
+ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/utils.hpp" +#include +#include + +namespace ck { +namespace host { +namespace conv { + +// return the relevant device op file based on the operation +// NOTE: this is a modified version of the original CK file that calls the kernel from a device +// function and makes the Argument class accessible on the device +std::string Problem_Conv_Fwd::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/" + "codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"; +} + +// return vector of forward convolution instances when provided with a problem instance +std::vector Problem_Conv_Fwd::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::conv::Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations( + *this, prologue, epilogue); + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); + }); + return result; +} + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000000000000000000000000000000000..36c9a13b4c340b907612f62249f75c15f703aea5 --- /dev/null +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
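// Minimal sketch of how a code-generating consumer might combine the
// GetIncludeHeader() and GetSolutions() implementations below (illustrative
// only; the helper name and the generated `using` aliases are made up):
// the header supplies the device op, each Solution supplies one templated
// instance.
#include <sstream>
#include <string>
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp"

inline std::string make_conv_instance_source(const ck::host::conv::Problem_Conv_Fwd& prob,
                                             const std::string& arch)
{
    std::ostringstream src;
    src << "#include \"" << prob.GetIncludeHeader() << "\"\n";

    int idx = 0;
    for(const auto& solution : prob.GetSolutions(arch, /*prologue=*/"", /*epilogue=*/""))
        src << "using instance_" << idx++ << " = " << solution.ToTemplateString() << ";\n";

    return src.str();
}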
+ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include +#include "ck/host/stringutils.hpp" +#include "ck/host/types.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace conv { + +// NOTE: in CK, MNKPadding is always used for forward convolution, so didn't +// add GemmSpec function here + +// function to update prologue/epilogue with user provided operation +void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& pro) +{ + if(!pro.empty()) + { + this->prologue = pro; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->prologue = ""; + } +} + +void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epi) +{ + if(!epi.empty()) + { + this->epilogue = epi; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->epilogue = ""; + } +} + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations( + const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK| +// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch| +// | | | | | | | | Wave| Wave| Stage| +// | | | | | | | | | | | + { 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1}, + { 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1}, + { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1} + // clang-format on + }; + + std::vector a_block_descriptions = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1} + // clang-format on + }; + + std::vector b_block_descriptions = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1} + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1} + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// 
| + { S<1, 16, 1, 4>, 1}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 4>, 1}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 8>, 8} + // clang-format on + }; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Conv_Fwd_Xdl_Cshuffle x; + x.NumDim = prob.NumDim; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, prob.ALayout}; + x.B = TensorDesc{prob.BDataType, prob.BLayout}; + x.E = TensorDesc{prob.EDataType, prob.ELayout}; + x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { + return TensorDesc{dt, lo}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } + return result; +} + +// set up instances when not provided with a problem specification, use default operation values +std::vector +Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(const std::string& prologue, + const std::string& epilogue) +{ + Problem_Conv_Fwd prob; + return CreateOperations(prob, prologue, epilogue); +} + +static const char* const CopyDevice_ConvTemplate = + R"( +${Prologue} +${Epilogue} + +using CDEElementOp = Epilogue; +using DeviceConv = ck::tensor_operation::device::CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<${NumDim}, ${LayoutA}, ${LayoutB}, ${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, ${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, ${CDEElementwiseOperation}, ${ConvSpecialization}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, ${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, ${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, ${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, ${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, ${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, ${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, ${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, ${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, ${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, ${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, ${CDEBlockTransferScalarPerVector_NPerBlock}>; + +constexpr ck::index_t NumATensor = ck::tensor_operation::device::GetNumABTensors(); +constexpr ck::index_t NumBTensor = ck::tensor_operation::device::GetNumABTensors(); + +extern "C" __global__ void run_${name}( + const ${ADataType}* in_dev, + const ${BDataType}* wei_dev, + ${EDataType}* __restrict__ out_dev, + ck::Array in_lengths, + ck::Array in_strides, + ck::Array wei_lengths, + ck::Array wei_strides, + ck::Array out_lengths, + ck::Array out_strides, + ck::Array conv_filter_strides, + ck::Array conv_filter_dilations, + ck::Array 
input_left_pads, + ck::Array input_right_pads, + const ${AElementwiseOperation} a_element_op, + const ${BElementwiseOperation} b_element_op, + const ${CDEElementwiseOperation} cde_element_op +){ + + + auto arg = DeviceConv::Argument(in_dev, + wei_dev, + ck::Array{}, + out_dev, + in_lengths, + in_strides, + wei_lengths, + wei_strides, + ck::Array, 0>{}, + ck::Array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + ${AElementwiseOperation}{}, + ${BElementwiseOperation}{}, + ${CDEElementwiseOperation}{1.0f, 1.0f}); + + if(!DeviceConv::IsSupportedArgument(arg)) + { + printf("Arguement is not supported.\n"); + return; + }; + + constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(); + + // GridwiseGemm + using GridwiseGemm = DeviceConv::GridwiseGemm; + + static constexpr auto I0 = ck::Number<0>{}; + + ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + const ${ADataType}*, + const ${BDataType}*, + typename GridwiseGemm::DsGridPointer, + ${EDataType}, + ${AElementwiseOperation}, + ${BElementwiseOperation}, + ${CDEElementwiseOperation}, + DeviceConv::AGridDesc_AK0_M_AK1, + DeviceConv::BGridDesc_BK0_N_BK1, + DeviceConv::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceConv::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceConv::Block2ETileMap, + ck::tensor_operation::device::ComputePtrOffsetOfStridedBatch, + ck::integral_constant{}, + false, + false> + ( + arg.p_as_grid_.At(I0), + arg.p_bs_grid_.At(I0), + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_ + ); + +} +)"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Conv_Fwd_Xdl_Cshuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.m_per_block) + "_" + + std::to_string(this->tile_desc.n_per_block) + "_" + + std::to_string(this->tile_desc.k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"NumDim", std::to_string(this->NumDim)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB", ToString(this->B.layout)}, + {"LayoutDs", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))}, + {"LayoutE", ToString(this->E.layout)}, + {"ADataType", ToString(this->A.element)}, + {"BDataType", ToString(this->B.element)}, + {"AccDataType", ToString(this->acc)}, + {"ComputeDataType", ToString(this->A.element)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"DsDataType", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))}, + {"EDataType", ToString(this->E.element)}, + {"AElementwiseOperation", this->a_elem_op}, + {"BElementwiseOperation", this->b_elem_op}, + {"CDEElementwiseOperation", this->cde_elem_op}, + {"Prologue", this->prologue}, + {"Epilogue", this->epilogue}, + 
{"ConvSpecialization", this->conv_specialization}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"MPerBlock", std::to_string(this->tile_desc.m_per_block)}, + {"NPerBlock", std::to_string(this->tile_desc.n_per_block)}, + {"KPerBlock", std::to_string(this->tile_desc.k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)}, + {"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"BBlockTransferThreadClusterLengths_BK0_N_BK1", + this->b_block_transfer.thread_cluster_length}, + {"BBlockTransferThreadClusterArrangeOrder", + this->b_block_transfer.thread_cluster_arrange_order}, + {"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order}, + {"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)}, + {"BBlockTransferSrcScalarPerVector", + std::to_string(this->b_block_transfer.src_scalar_per_vector)}, + {"BBlockTransferDstScalarPerVector_BK1", + std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)}, + {"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CDEBlockTransferScalarPerVector_NPerBlock", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + }; + + return Solution{InterpolateString(CopyDevice_ConvTemplate, values), std::move(values)}; +} + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b0c929db32fc9a834ec040163b2a5476e2b6d1c --- /dev/null +++ b/codegen/src/headers.cpp @@ -0,0 +1,20 @@ +#include "ck/host/headers.hpp" +#include "ck_headers.hpp" + +namespace ck { +namespace host { + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wglobal-constructors" +const std::string config_header = ""; +#pragma clang diagnostic pop + +std::unordered_map GetHeaders() +{ + auto headers = ck_headers(); + headers.insert(std::make_pair("ck/config.h", config_header)); + return headers; +} + +} // namespace host +} // namespace ck diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..a8a8b10c04d522e93dc7340167e46ca83f51259b --- /dev/null +++ b/codegen/src/types.cpp @@ -0,0 +1,71 @@ +#include "ck/host/types.hpp" +#include "ck/host/stringutils.hpp" +#include +#include + +namespace ck { +namespace host { + +Solution::Solution(std::string str, std::unordered_map values) + : template_str(std::move(str)), template_values(std::move(values)) +{ +} + +std::string Solution::ToTemplateString() const { return this->template_str; } +std::string Solution::GetTemplateParameter(const std::string& name) const +{ + return this->template_values.at(name); +} + +std::string ToString(DataType dt) +{ + switch(dt) + { + case DataType::Float: return "float"; + case DataType::Half: return "ck::half_t"; + case DataType::Int8: return "int8_t"; + case DataType::Int32: return "int32_t"; + } + throw std::runtime_error("Incorrect data type"); +} + +Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + +std::string ToString(Layout dl) +{ + switch(dl) + { + case Layout::Row: return "ck::tensor_layout::gemm::RowMajor"; + case Layout::Column: return "ck::tensor_layout::gemm::ColumnMajor"; + case Layout::GKCYX: return "ck::tensor_layout::convolution::GKCYX"; + case Layout::GKYXC: return "ck::tensor_layout::convolution::GKYXC"; + case Layout::GNHWK: return "ck::tensor_layout::convolution::GNHWK"; + case Layout::GNHWC: return "ck::tensor_layout::convolution::GNHWC"; + case Layout::NHWGC: return "ck::tensor_layout::convolution::NHWGC"; + case Layout::NHWGK: return "ck::tensor_layout::convolution::NHWGK"; + } + throw std::runtime_error("Incorrect layout"); +} + +std::string ToString(GemmType gt) +{ + switch(gt) + { + case GemmType::Default: return "ck::tensor_operation::device::GemmSpecialization::Default"; + } + throw std::runtime_error("Incorrect gemm type"); +} + +std::string SequenceStr(const std::vector& v) +{ + return "ck::Sequence<" + + JoinStrings(Transform(v, [](int x) { return std::to_string(x); }), ", ") + ">"; +} + +std::string MakeTuple(const std::vector& v) +{ + return "ck::Tuple<" + JoinStrings(v, ", ") + ">"; +} + +} // namespace host +} // namespace ck diff --git a/codegen/src/utils.cpp b/codegen/src/utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..19627d4cf6fa6bcdacee4519a1929b93093cf527 --- /dev/null +++ b/codegen/src/utils.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/utils.hpp" + +namespace ck { +namespace host { + +std::size_t integer_divide_ceil(std::size_t x, std::size_t y) +{ + return (x + y - std::size_t{1}) / y; +} + +const std::unordered_set& get_xdlop_archs() +{ + static std::unordered_set supported_archs{"gfx90a", "gfx908", "gfx940", "gfx942"}; + return supported_archs; +} + +} // namespace host +} // namespace ck diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..48fde531da6bae4ccf1f4a994c55b24244557762 --- /dev/null +++ b/codegen/test/CMakeLists.txt @@ -0,0 +1,25 @@ +list(APPEND CMAKE_PREFIX_PATH /opt/rocm) +add_subdirectory(rtc) +file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) + +# TODO: These tests need to be refactored to remove dependency on main ck +# headers and device compilation. 
+set(TESTS_REQUIRE_DEVICE_COMPILE + grouped_conv_fwd_multiple_d_v1 + grouped_conv_fwd_multiple_d_v2 + grouped_conv_fwd_multiple_d_v3 + grouped_conv_fwd_multiple_d_v4 +) +find_package(hip) + +foreach(TEST_SRC ${TEST_SRCS}) + get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) + rocm_add_test_executable(codegen_test_${BASE_NAME} ${TEST_SRC}) + target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host) + target_include_directories(codegen_test_${BASE_NAME} PUBLIC include) + if(BASE_NAME IN_LIST TESTS_REQUIRE_DEVICE_COMPILE) + target_link_libraries(codegen_test_${BASE_NAME} hip::device) + target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/include) + target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include) + endif() +endforeach() diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bd7ef463fbe64d5bc3d07665cb4757598657f2ad --- /dev/null +++ b/codegen/test/gemm_multiple_d.cpp @@ -0,0 +1,190 @@ +#include "ck/host/device_gemm_multiple_d/problem.hpp" +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +using half = _Float16; +// using half = __fp16; + +std::vector get_headers_for_test() +{ + std::vector result; + auto hs = ck::host::GetHeaders(); + std::transform( + hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { + return {p.first, p.second}; + }); + return result; +} + +template +rtc::buffer generate_buffer(std::size_t n, std::size_t seed = 0) +{ + rtc::buffer result(n); + std::mt19937 gen(seed); + std::uniform_real_distribution dis(-1.0); + std::generate(result.begin(), result.end(), [&] { return dis(gen); }); + return result; +} + +template +bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) +{ + return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { + return fabs(x - y) < atol + rtol * fabs(y); + }); +} + +std::string classify(double x) +{ + switch(std::fpclassify(x)) + { + case FP_INFINITE: return "inf"; + case FP_NAN: return "nan"; + case FP_NORMAL: return "normal"; + case FP_SUBNORMAL: return "subnormal"; + case FP_ZERO: return "zero"; + default: return "unknown"; + } +} + +template +void print_classification(const Buffer& x) +{ + std::unordered_set result; + for(const auto& i : x) + result.insert(classify(i)); + for(const auto& c : result) + std::cout << c << ", "; + std::cout << std::endl; +} + +template +void print_statistics(const Buffer& x) +{ + std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; + std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; + double num_elements = x.size(); + auto mean = + std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; + auto stddev = std::sqrt( + std::accumulate(x.begin(), + x.end(), + double{0.0}, + [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / + num_elements); + std::cout << "Mean: " << mean << ", "; + std::cout << "StdDev: " << stddev << "\n"; +} + +template +void print_preview(const Buffer& x) +{ + if(x.size() <= 10) + { + std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); + } + else + { + std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); + std::cout << "..., "; + 
std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); + } + std::cout << std::endl; +} + +template +struct check_all +{ + rtc::buffer data{}; + bool operator()(const rtc::buffer& x) + { + if(data.empty()) + { + data = x; + return true; + } + if(std::any_of(x.begin(), x.end(), [](double y) { return std::isnan(y); })) + return false; + return allclose(data, x); + } +}; + +template +auto report(const Solution& solution, bool pass) +{ + return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); +} + +const std::string gemm_compile_check = R"__ck__( +#include <${include}> + +extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) { + using G = ${template}; + constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), + ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})), + ck::make_tuple(), + ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n}))); + + static_assert(desc.IsValid(), "Invalid ck gemm."); + + if constexpr(desc.IsValid()) + { + ${template}::Run(desc, + a, + b, + ck::make_tuple(), + c); + } +} + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + ck::host::device_gemm_multiple_d::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + check_all check; + auto a = to_gpu(generate_buffer(1024 * 1024, 0)); + auto b = to_gpu(generate_buffer(1024 * 1024, 1)); + auto c = to_gpu(generate_buffer(1024 * 1024, 2)); + + std::string epilogue = ""; + std::string prologue = ""; + + for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue)) + { + auto src = ck::host::InterpolateString(gemm_compile_check, + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}}); + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + options.kernel_name = "f"; + auto k = rtc::compile_kernel(srcs, options); + auto block_size = solution.GetTemplateParameter("BlockSize"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * + ck::host::integer_divide_ceil(prob.N, n_per_block); + k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data()); + + CHECK(report(solution, check(rtc::from_gpu(c)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..50290fa25ad385ce1e657de8cac3042227cc6787 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -0,0 +1,207 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include +#include +#include +#include "common.hpp" +#include + +// Need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), 
beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + + ck::Array conv_filter_strides = {2, 2}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {1, 1}; + ck::Array input_right_pads = {1, 1}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {2, 2}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {1, 1}; + std::vector input_right_pads_ = {1, 1}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + 
input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b558d97c783c7f9a2a0901013c5ea4cee053d175 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -0,0 +1,207 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include +#include +#include +#include + +// need this for validation +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ 
constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + + ck::Array conv_filter_strides = {1, 1}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {0, 0}; + ck::Array input_right_pads = {0, 0}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {1, 1}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {0, 0}; + std::vector input_right_pads_ = {0, 0}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // 
launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e2972a93d2f7e11aa6b03dc35b7aae6663e70f93 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -0,0 +1,207 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "common.hpp" +#include +#include +#include +#include + +// need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array 
out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + + ck::Array conv_filter_strides = {2, 2}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {0, 0}; + ck::Array input_right_pads = {0, 0}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {2, 2}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {0, 0}; + std::vector input_right_pads_ = {0, 0}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..b728096c51e80d0d72f407160275e3699cf5a16a --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -0,0 +1,207 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "common.hpp" +#include +#include +#include +#include + +// need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + + ck::Array conv_filter_strides = {1, 1}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {1, 1}; + ck::Array input_right_pads = {1, 1}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + 
in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {1, 1}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {1, 1}; + std::vector input_right_pads_ = {1, 1}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..99d4c6497331f65d19adf302bd47dbaa22ac4b40 --- /dev/null +++ b/codegen/test/include/common.hpp @@ -0,0 +1,134 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +std::vector get_headers_for_test() +{ + std::vector result; + auto hs = ck::host::GetHeaders(); + std::transform( + hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { + return {p.first, p.second}; + }); + return result; +} + +template +std::size_t GetSize(V mLens, V mStrides) +{ + std::size_t space = 1; + for(std::size_t i = 0; i < mLens.Size(); ++i) + { + if(mLens[i] == 0) + continue; + + space += (mLens[i] - 1) * mStrides[i]; + } + return space; +} + +template +rtc::buffer generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +{ + std::size_t space = GetSize(mLens, mStrides); + rtc::buffer result(space); + std::mt19937 gen(seed); + 
std::uniform_real_distribution dis(-1.0); + std::generate(result.begin(), result.end(), [&] { return dis(gen); }); + // std::fill(result.begin(), result.end(), 1); + return result; +} + +template +bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) +{ + return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { + return fabs(x - y) < atol + rtol * fabs(y); + }); +} + +std::string classify(double x) +{ + switch(std::fpclassify(x)) + { + case FP_INFINITE: return "inf"; + case FP_NAN: return "nan"; + case FP_NORMAL: return "normal"; + case FP_SUBNORMAL: return "subnormal"; + case FP_ZERO: return "zero"; + default: return "unknown"; + } +} + +template +void print_classification(const Buffer& x) +{ + std::unordered_set result; + for(const auto& i : x) + result.insert(classify(i)); + for(const auto& c : result) + std::cout << c << ", "; + std::cout << std::endl; +} + +template +void print_statistics(const Buffer& x) +{ + std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; + std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; + double num_elements = x.size(); + auto mean = + std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; + auto stddev = std::sqrt( + std::accumulate(x.begin(), + x.end(), + double{0.0}, + [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / + num_elements); + std::cout << "Mean: " << mean << ", "; + std::cout << "StdDev: " << stddev << "\n"; +} + +template +void print_preview(const Buffer& x) +{ + if(x.size() <= 10) + { + std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); + } + else + { + std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); + std::cout << "..., "; + std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); + } + std::cout << std::endl; +} + +template +struct check_all +{ + rtc::buffer data{}; + bool operator()(const rtc::buffer& x) + { + if(data.empty()) + { + data = x; + return true; + } + return allclose(data, x); + } +}; + +template +auto report(const Solution& solution, bool pass) +{ + return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); +} diff --git a/codegen/test/include/test.hpp b/codegen/test/include/test.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c3e38d6002621b8fcb3aa283b2fde4987930f715 --- /dev/null +++ b/codegen/test/include/test.hpp @@ -0,0 +1,848 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __linux__ +#include +#endif + +#ifndef MIGRAPHX_GUARD_TEST_TEST_HPP +#define MIGRAPHX_GUARD_TEST_TEST_HPP + +namespace test { +// clang-format off +// NOLINTNEXTLINE +#define TEST_FOREACH_BINARY_OPERATORS(m) \ + m(==, equal) \ + m(!=, not_equal) \ + m(<=, less_than_equal) \ + m(>=, greater_than_equal) \ + m(<, less_than) \ + m(>, greater_than) \ + m(and, and_op) \ + m(or, or_op) +// clang-format on + +// clang-format off +// NOLINTNEXTLINE +#define TEST_FOREACH_UNARY_OPERATORS(m) \ + m(not, not_op) +// clang-format on + +// NOLINTNEXTLINE +#define TEST_EACH_BINARY_OPERATOR_OBJECT(op, name) \ + struct name \ + { \ + static std::string as_string() { return #op; } \ + template \ + static decltype(auto) call(T&& x, U&& y) \ + { \ + return x op y; \ + } \ + }; + +// NOLINTNEXTLINE +#define TEST_EACH_UNARY_OPERATOR_OBJECT(op, name) \ + struct name \ + { \ + static std::string as_string() { return #op; } \ + template \ + static decltype(auto) call(T&& x) \ + { \ + return op x; \ + } \ + }; + +TEST_FOREACH_BINARY_OPERATORS(TEST_EACH_BINARY_OPERATOR_OBJECT) +TEST_FOREACH_UNARY_OPERATORS(TEST_EACH_UNARY_OPERATOR_OBJECT) + +struct nop +{ + static std::string as_string() { return ""; } + template + static auto call(T&& x) + { + return static_cast(x); + } +}; + +struct function +{ + static std::string as_string() { return ""; } + template + static decltype(auto) call(T&& x) + { + return x(); + } +}; + +template +Stream& stream_range(Stream& s, Iterator start, Iterator last); + +template +inline Stream& operator<<(Stream& s, std::nullptr_t) +{ + s << "nullptr"; + return s; +} + +template {}>::type> +inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.begin(), v.end())) +{ + s << "{ "; + stream_range(s, v.begin(), v.end()); + s << "}"; + return s; +} + +template +inline Stream& stream_range(Stream& s, Iterator start, Iterator last) +{ + if(start != last) + { + s << *start; + std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; }); + } + return s; +} + +template +const T& get_value(const T& x) +{ + return x; +} + +template +struct lhs_expression; + +template +lhs_expression make_lhs_expression(T&& lhs); + +template +lhs_expression make_lhs_expression(T&& lhs, Operator); + +// NOLINTNEXTLINE +#define TEST_EXPR_BINARY_OPERATOR(op, name) \ + template \ + auto operator op(const V& rhs2) const \ + { \ + return make_expression(*this, rhs2, name{}); /* NOLINT */ \ + } + +// NOLINTNEXTLINE +#define TEST_EXPR_UNARY_OPERATOR(op, name) \ + auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ } + +template +struct expression +{ + T lhs; + U rhs; + + friend std::ostream& operator<<(std::ostream& s, const expression& self) + { + s << self.lhs << " " << Operator::as_string() << " " << self.rhs; + return s; + } + + friend decltype(auto) get_value(const expression& e) { return e.value(); } + + decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); }; + + TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) + TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) +}; + +// TODO: Remove rvalue references 
+template +expression make_expression(T&& rhs, U&& lhs, Operator) +{ + return {std::forward(rhs), std::forward(lhs)}; +} + +// TODO: Remove rvalue reference +template +lhs_expression make_lhs_expression(T&& lhs) +{ + return lhs_expression{std::forward(lhs)}; +} + +template +lhs_expression make_lhs_expression(T&& lhs, Operator) +{ + return lhs_expression{std::forward(lhs)}; +} + +template +struct lhs_expression +{ + T lhs; + explicit lhs_expression(T e) : lhs(e) {} + + friend std::ostream& operator<<(std::ostream& s, const lhs_expression& self) + { + std::string op = Operator::as_string(); + if(not op.empty()) + s << Operator::as_string() << " "; + s << self.lhs; + return s; + } + + friend decltype(auto) get_value(const lhs_expression& e) { return e.value(); } + + decltype(auto) value() const { return Operator::call(get_value(lhs)); } + + TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) + TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) + +// NOLINTNEXTLINE +#define TEST_LHS_REOPERATOR(op) \ + template \ + auto operator op(const U& rhs) const \ + { \ + return make_lhs_expression(lhs op rhs); \ + } + TEST_LHS_REOPERATOR(+) + TEST_LHS_REOPERATOR(-) + TEST_LHS_REOPERATOR(*) + TEST_LHS_REOPERATOR(/) + TEST_LHS_REOPERATOR(%) + TEST_LHS_REOPERATOR(&) + TEST_LHS_REOPERATOR(|) + TEST_LHS_REOPERATOR(^) +}; + +template +struct predicate +{ + std::string msg; + F f; + + friend std::ostream& operator<<(std::ostream& s, const predicate& self) + { + s << self.msg; + return s; + } + + decltype(auto) operator()() const { return f(); } + + operator decltype(auto)() const { return f(); } +}; + +template +auto make_predicate(const std::string& msg, F f) +{ + return make_lhs_expression(predicate{msg, f}, function{}); +} + +inline std::string as_string(bool x) +{ + if(x) + return "true"; + return "false"; +} + +template +std::string as_string(const T& x) +{ + std::stringstream ss; + ss << x; + return ss.str(); +} + +template +std::string as_string(Iterator start, Iterator last) +{ + std::stringstream ss; + stream_range(ss, start, last); + return ss.str(); +} + +template +auto make_function(const std::string& name, F f) +{ + return [=](auto&&... 
xs) { + std::vector args = {as_string(xs)...}; + return make_predicate(name + "(" + as_string(args.begin(), args.end()) + ")", + [=] { return f(xs...); }); + }; +} + +struct capture +{ + template + auto operator->*(const T& x) const + { + return make_lhs_expression(x); + } + + template + auto operator->*(const lhs_expression& x) const + { + return x; + } +}; + +enum class color +{ + reset = 0, + bold = 1, + underlined = 4, + fg_red = 31, + fg_green = 32, + fg_yellow = 33, + fg_blue = 34, + fg_default = 39, + bg_red = 41, + bg_green = 42, + bg_yellow = 43, + bg_blue = 44, + bg_default = 49 +}; +inline std::ostream& operator<<(std::ostream& os, const color& c) +{ +#ifndef _WIN32 + static const bool use_color = isatty(STDOUT_FILENO) != 0; + if(use_color) + return os << "\033[" << static_cast(c) << "m"; +#else + (void)c; +#endif + return os; +} + +inline std::atomic& failures() +{ + // NOLINTNEXTLINE + static std::atomic f = 0; + return f; +} + +template +void failed(T x, const char* msg, const char* func, const char* file, int line, F f) +{ + if(not bool(x.value())) + { + failures()++; + std::cout << func << std::endl; + std::cout << file << ":" << line << ":" << std::endl; + std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " " + << "[ " << x << " ]" << std::endl; + f(); + } +} + +template +bool throws(F f) +{ + try + { + f(); + return false; + } + catch(...) + { + return true; + } +} + +template +bool throws(F f, const std::string& msg = "") +{ + try + { + f(); + return false; + } + catch(const Exception& ex) + { + return std::string(ex.what()).find(msg) != std::string::npos; + } +} + +template +auto within_abs(T px, U py, double ptol = 1e-6f) +{ + return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })( + px, py, ptol); +} + +// This implements the basic globbing algorithm where `*` matches any number +// of characters(including none) and `?` matches any single character. It +// doesnt support character classes. +// +// This is a simple recursive implementation that scans the string where the +// string and pattern matches. When a `*` is found in the pattern, the +// `glob_match` function is called recursively to compare the rest of the +// pattern to the rest of the string. If the recursive call returns true, +// then we have a match. However, if it returns false, then we advance one +// character and call the recusrsive call again. This is referred to as a +// star-loop, which will consume zero or more characters. +// +// This simple recursive implementation works well for short string and +// patterns with few stars. First, it is unlikely to use many stars to glob +// test names. Secondly, using many stars is still signficantly faster than +// using the equivalent std::regex, which has a much slower time complexity. 
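+// A minimal usage sketch of glob_match (defined just below), assuming two
+// std::string inputs: a test-case name and a command-line glob pattern:
+//
+//     std::string name    = "test_gemm_fp16";
+//     std::string pattern = "test_*_fp16";
+//     bool matched = glob_match(name.begin(), name.end(),
+//                               pattern.begin(), pattern.end()); // true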
+template +bool glob_match(Iterator1 start, Iterator1 last, Iterator2 pattern_start, Iterator2 pattern_last) +{ + std::tie(start, pattern_start) = + std::mismatch(start, last, pattern_start, pattern_last, [](auto c, auto m) { + if(m == '?') + return true; + // We need a loop for star, so bail and handle the loop below + if(m == '*') + return false; + return c == m; + }); + // If there is no more pattern then return true if there is no more string to match + if(pattern_start == pattern_last) + return start == last; + // If the pattern is not a star then its a mismatch + if(*pattern_start != '*') + return false; + // Multiple stars are the same as a single star so skip over multiple stars + pattern_start = std::find_if(pattern_start, pattern_last, [](auto c) { return c != '*'; }); + // If the star is at the end then return true + if(pattern_start == pattern_last) + return true; + // star-loop: match the rest of the pattern and text + while(not glob_match(start, last, pattern_start, pattern_last) and start != last) + start++; + // If the string is empty then it means a match was never found + return start != last; +} + +using string_map = std::unordered_map>; + +template +string_map generic_parse(std::vector as, Keyword keyword) +{ + string_map result; + + std::string flag; + for(auto&& x : as) + { + auto f = keyword(x); + if(f.empty()) + { + result[flag].push_back(x); + } + else + { + flag = f.front(); + result[flag]; // Ensure the flag exists + flag = f.back(); + } + } + return result; +} + +using test_case = std::function; + +inline auto& get_test_cases() +{ + // NOLINTNEXTLINE + static std::vector> cases; + return cases; +} + +inline void add_test_case(std::string name, test_case f) +{ + get_test_cases().emplace_back(std::move(name), std::move(f)); +} + +struct auto_register_test_case +{ + template + auto_register_test_case(const char* name, F f) noexcept + { + add_test_case(name, f); + } +}; + +struct failure_error +{ +}; + +[[noreturn]] inline void fail() { throw failure_error{}; } + +struct driver +{ + driver() + { + add_flag({"--help", "-h"}, "Show help"); + add_flag({"--list", "-l"}, "List all test cases"); + add_flag({"--continue", "-c"}, "Continue after failure"); + add_flag({"--quiet", "-q"}, "Don't print out extra output"); + } + struct argument + { + std::vector flags = {}; + std::string help = ""; + int nargs = 1; + }; + + void add_arg(const std::vector& flags, const std::string& help = "") + { + arguments.push_back(argument{flags, help, 1}); + } + + void add_flag(const std::vector& flags, const std::string& help = "") + { + arguments.push_back(argument{flags, help, 0}); + } + + static void wrap(std::ostream& os, + const std::string& text, + const std::string& prefix = "", + unsigned int line_length = 80) + { + std::istringstream iss(text); + std::string line = prefix; + do + { + std::string word; + iss >> word; + if(line.length() + word.length() > line_length) + { + os << line << std::endl; + line = prefix; + } + line += word + " "; + } while(iss); + if(not line.empty()) + os << line << std::endl; + } + + void show_help(const std::string& exe) const + { + const std::string prefix = " "; + std::cout << std::endl; + std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl; + std::cout << " "; + std::cout << exe << " ... " << std::endl; + std::cout << std::endl; + + std::cout << color::fg_yellow << "ARGS:" << color::reset << std::endl; + std::cout << " "; + std::cout << color::fg_green << "..." 
<< color::reset; + std::cout << std::endl; + + wrap(std::cout, + "Test cases to run. A test case can be either the exact test case name or a glob. A " + "glob expression uses a '*' to select zero or more characters or a '?' to select any " + "single character.", + prefix + prefix); + + std::cout << std::endl; + std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl; + for(auto&& arg : arguments) + { + std::cout << color::fg_green; + std::string arg_prefix = prefix; + for(const std::string& a : arg.flags) + { + std::cout << arg_prefix; + std::cout << a; + arg_prefix = ", "; + } + std::cout << color::reset << std::endl; + wrap(std::cout, arg.help, prefix + prefix); + } + } + + std::ostream& out() const + { + struct null_buffer : std::streambuf + { + virtual int overflow(int c) override { return c; } + }; + static null_buffer buffer; + static std::ostream null_stream(&buffer); + if(quiet) + return null_stream; + return std::cout; + } + + string_map parse(int argc, const char* argv[]) const + { + std::vector args(argv + 1, argv + argc); + string_map keys; + for(auto&& arg : arguments) + { + for(auto&& flag : arg.flags) + { + keys[flag] = {arg.flags.front()}; + if(arg.nargs == 0) + keys[flag].push_back(""); + } + } + auto result = generic_parse(args, [&](auto&& s) -> std::vector { + if(keys.count(s) > 0) + return keys[s]; + else + return {}; + }); + result["__exe__"].push_back(argv[0]); + return result; + } + + static std::string create_command(const string_map& args) + { + std::stringstream ss; + ss << args.at("__exe__").front(); + if(args.count("") > 0) + { + for(auto&& arg : args.at("")) + ss << " \"" << arg << "\""; + } + for(auto&& p : args) + { + if(p.first == "__exe__") + continue; + if(p.first.empty()) + continue; + ss << " " << p.first; + for(auto&& arg : p.second) + ss << " \"" << arg << "\""; + } + return ss.str(); + } + + static std::string fork(const std::string& name, string_map args) + { + std::string msg; + args[""] = {name}; + args.erase("--continue"); + args["--quiet"]; + auto cmd = create_command(args); + auto r = std::system(cmd.c_str()); // NOLINT + if(r != 0) + msg = "Exited with " + std::to_string(r); + return msg; + } + + static std::vector> glob_tests(const std::string& pattern) + { + std::vector> result; + std::copy_if(get_test_cases().begin(), + get_test_cases().end(), + std::back_inserter(result), + [&](auto&& p) { + return glob_match( + p.first.begin(), p.first.end(), pattern.begin(), pattern.end()); + }); + return result; + } + + void run_test_case(const std::string& name, const test_case& f, const string_map& args) + { + ran++; + out() << color::fg_green << "[ RUN ] " << color::reset << color::bold << name + << color::reset << std::endl; + std::string msg; + auto start = std::chrono::steady_clock::now(); + if(args.count("--continue") > 0) + { + msg = fork(name, args); + } + else + { + try + { + failures() = 0; + f(); + } + // cppcheck-suppress migraphx-EmptyCatchStatement + catch(const failure_error&) + { + } + } + auto finish = std::chrono::steady_clock::now(); + auto elapsed_ms = + std::chrono::duration_cast>(finish - start) + .count(); + if(msg.empty() and failures() != 0) + { + if(failures() == 1) + msg = "Test failure"; + else + msg = std::to_string(failures()) + " test failures"; + } + if(msg.empty()) + { + out() << color::fg_green << "[ COMPLETE ] " << color::reset; + } + else + { + failed.push_back(name); + out() << color::fg_red << "[ FAILED ] " << color::reset; + } + out() << color::bold << name << color::reset; + out() << 
color::fg_blue << " (" << elapsed_ms << "ms)" << color::reset; + if(not msg.empty()) + out() << ": " << color::fg_yellow << msg << color::reset; + out() << std::endl; + } + + void run(int argc, const char* argv[]) + { + auto args = parse(argc, argv); + if(args.count("--help") > 0) + { + show_help(args.at("__exe__").front()); + return; + } + if(args.count("--list") > 0) + { + for(auto&& tc : get_test_cases()) + out() << tc.first << std::endl; + return; + } + + if(args.count("--quiet") > 0) + quiet = true; + + auto cases = args[""]; + if(cases.empty()) + { + for(auto&& tc : get_test_cases()) + run_test_case(tc.first, tc.second, args); + } + else + { + std::unordered_map m(get_test_cases().begin(), + get_test_cases().end()); + + for(auto&& iname : cases) + { + std::vector> found_cases; + for(auto&& pattern : get_case_names(iname)) + { + auto f = m.find(pattern); + if(f == m.end()) + { + found_cases = glob_tests(pattern); + } + else + { + found_cases.push_back(*f); + } + } + if(found_cases.empty()) + { + out() << color::fg_red << "[ ERROR ] Test case '" << iname << "' not found." + << color::reset << std::endl; + failed.push_back(iname); + } + for(auto&& p : found_cases) + run_test_case(p.first, p.second, args); + } + } + out() << color::fg_green << "[==========] " << color::fg_yellow << ran << " tests ran" + << color::reset << std::endl; + if(not failed.empty()) + { + out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << failed.size() + << " tests failed" << color::reset << std::endl; + for(auto&& name : failed) + out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << name + << color::reset << std::endl; + std::exit(1); + } + } + + std::function(const std::string&)> get_case_names = + [](const std::string& name) -> std::vector { return {name}; }; + std::vector arguments = {}; + std::vector failed = {}; + std::size_t ran = 0; + bool quiet = false; +}; + +inline void run(int argc, const char* argv[]) +{ + driver d{}; + d.run(argc, argv); +} + +} // namespace test + +// NOLINTNEXTLINE +#define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__ + +// NOLINTNEXTLINE +#define CHECK(...) \ + test::failed( \ + TEST_CAPTURE(__VA_ARGS__), #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] {}) + +// NOLINTNEXTLINE +#define EXPECT(...) \ + test::failed(TEST_CAPTURE(__VA_ARGS__), \ + #__VA_ARGS__, \ + __PRETTY_FUNCTION__, \ + __FILE__, \ + __LINE__, \ + &test::fail) +// NOLINTNEXTLINE +#define STATUS(...) EXPECT((__VA_ARGS__) == 0) + +// NOLINTNEXTLINE +#define TEST_CAT(x, ...) TEST_PRIMITIVE_CAT(x, __VA_ARGS__) +// NOLINTNEXTLINE +#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__ + +// NOLINTNEXTLINE +#define TEST_CASE_REGISTER(...) \ + static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \ + test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__); + +// NOLINTNEXTLINE +#define TEST_CASE(...) 
\ + void __VA_ARGS__(); \ + TEST_CASE_REGISTER(__VA_ARGS__) \ + void __VA_ARGS__() + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wglobal-constructors" +#endif + +#endif diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..68bfc2467b089c723a2e3d25bc8dad23ce6d694b --- /dev/null +++ b/codegen/test/rtc/CMakeLists.txt @@ -0,0 +1,6 @@ +find_package(hip) +file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) +add_library(ck_rtc ${RTC_SOURCES}) +target_include_directories(ck_rtc PUBLIC include) +target_link_libraries(ck_rtc PUBLIC hip::host) +target_link_libraries(ck_rtc PUBLIC -lstdc++fs) diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c4413b47be2b23a36dd2a631794876dae8b98776 --- /dev/null +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -0,0 +1,27 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL + +#include +#include +#include + +namespace rtc { + +struct src_file +{ + fs::path path; + std::string_view content; +}; + +struct compile_options +{ + std::string flags = ""; + std::string kernel_name = "main"; +}; + +kernel compile_kernel(const std::vector& src, + compile_options options = compile_options{}); + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/filesystem.hpp b/codegen/test/rtc/include/rtc/filesystem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3b94b84b9fd93ad038bc9f28677a3d4b7dae7625 --- /dev/null +++ b/codegen/test/rtc/include/rtc/filesystem.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
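+// A minimal end-to-end sketch of the rtc test helpers declared in
+// rtc/compile_kernel.hpp, rtc/hip.hpp and rtc/kernel.hpp; the kernel source,
+// names and launch sizes below are hypothetical:
+//
+//     rtc::src_file src{"main.cpp",
+//                       "#include <hip/hip_runtime.h>\n"
+//                       "__global__ void kernel_main(float* p)"
+//                       "{ p[threadIdx.x] += 1.0f; }\n"};
+//     rtc::compile_options options;
+//     options.kernel_name = "kernel_main";
+//     rtc::kernel k = rtc::compile_kernel({src}, options);
+//
+//     rtc::buffer<float> input(64);
+//     auto device = rtc::to_gpu(input);
+//     std::vector<rtc::kernel_argument> args = {device.data()};
+//     k.launch(nullptr, 64, 64, args); // 64 threads total, 64 per block
+//     auto result = rtc::from_gpu(device);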
+ +#ifndef GUARD_TEST_HOST_RTC_FILESYSTEM_HPP +#define GUARD_TEST_HOST_RTC_FILESYSTEM_HPP + +#include +#include + +// clang-format off +#if defined(CPPCHECK) + #define RTC_HAS_FILESYSTEM 1 + #define RTC_HAS_FILESYSTEM_TS 1 +#elif defined(_WIN32) + #if _MSC_VER >= 1920 + #define RTC_HAS_FILESYSTEM 1 + #define RTC_HAS_FILESYSTEM_TS 0 + #elif _MSC_VER >= 1900 + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 1 + #else + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 0 + #endif +#elif defined(__has_include) + #if __has_include() && __cplusplus >= 201703L + #define RTC_HAS_FILESYSTEM 1 + #else + #define RTC_HAS_FILESYSTEM 0 + #endif + #if __has_include() && __cplusplus >= 201103L + #define RTC_HAS_FILESYSTEM_TS 1 + #else + #define RTC_HAS_FILESYSTEM_TS 0 + #endif +#else + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 0 +#endif +// clang-format on + +#if RTC_HAS_FILESYSTEM +#include +#elif RTC_HAS_FILESYSTEM_TS +#include +#else +#error "No filesystem include available" +#endif + +namespace rtc { + +#if RTC_HAS_FILESYSTEM +namespace fs = ::std::filesystem; +#elif RTC_HAS_FILESYSTEM_TS +namespace fs = ::std::experimental::filesystem; +#endif + +} // namespace rtc + +#endif // GUARD_RTC_FILESYSTEM_HPP_ diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6b523382dce35b08b109d70dcd15417b76d650a7 --- /dev/null +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -0,0 +1,78 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP + +#include +#include +#include + +namespace rtc { + +template +struct buffer +{ + buffer() : ptr(), n(0) {} + buffer(std::shared_ptr p, std::size_t sz) : ptr(p), n(sz) {} + buffer(std::shared_ptr p, std::size_t sz) + : ptr(std::reinterpret_pointer_cast(p)), n(sz) + { + } + explicit buffer(std::size_t sz) : ptr(new T[sz]), n(sz) {} + T* begin() { return data(); } + T* end() { return data() + size(); } + const T* begin() const { return data(); } + const T* end() const { return data() + size(); } + + T& front() { return data()[0]; } + T& back() { return data()[size() - 1]; } + T& operator[](std::size_t i) { return data()[i]; } + T& at(std::size_t i) + { + if(i >= size()) + throw std::runtime_error("Out of bounds"); + return data()[i]; + } + + const T& front() const { return data()[0]; } + const T& back() const { return data()[size() - 1]; } + const T& operator[](std::size_t i) const { return data()[i]; } + const T& at(std::size_t i) const + { + if(i >= size()) + throw std::runtime_error("Out of bounds"); + return data()[i]; + } + const T* data() const { return ptr.get(); } + T* data() { return ptr.get(); } + + std::size_t size() const { return n; } + std::size_t bytes() const { return size() * sizeof(T); } + + bool empty() const { return size() == 0; } + + private: + std::shared_ptr ptr; + std::size_t n; +}; + +std::string get_device_name(); +std::string hip_error(int error); + +std::shared_ptr allocate_gpu(std::size_t sz, bool host = false); +std::shared_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false); +std::shared_ptr read_from_gpu(const void* x, std::size_t sz); + +template +buffer to_gpu(const buffer& input) +{ + return {write_to_gpu(input.data(), input.bytes()), input.size()}; +} + +template +buffer from_gpu(const buffer& input) +{ + return {read_from_gpu(input.data(), input.bytes()), input.size()}; +} + +} // namespace rtc + +#endif diff --git 
a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9f38e90416e0d2363a921df1ac4268bbc82e55ff --- /dev/null +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -0,0 +1,62 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL + +#include +#include +#include +#include + +namespace rtc { + +struct kernel_argument +{ + template , + class = std::enable_if_t{}>> + kernel_argument(T&& x) : size(sizeof(U)), align(alignof(U)), data(&x) // NOLINT + { + } + std::size_t size; + std::size_t align; + void* data; +}; + +std::vector pack_args(const std::vector& args); + +struct kernel_impl; + +struct kernel +{ + kernel() = default; + kernel(const char* image, const std::string& name); + template + kernel(const std::vector& image, const std::string& name) + : kernel(reinterpret_cast(image.data()), name) + { + static_assert(sizeof(T) == 1, "Only byte types"); + } + + void launch(hipStream_t stream, + std::size_t global, + std::size_t local, + const std::vector& args) const; + + void launch(hipStream_t stream, + std::size_t global, + std::size_t local, + std::vector args) const; + + template + auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const + { + return [=](auto&&... xs) { + launch(stream, global, local, std::vector{xs...}, zs...); + }; + } + + private: + std::shared_ptr impl; +}; +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/manage_ptr.hpp b/codegen/test/rtc/include/rtc/manage_ptr.hpp new file mode 100644 index 0000000000000000000000000000000000000000..92edf1262832d5e69c5751b162d9b5a43aac5a58 --- /dev/null +++ b/codegen/test/rtc/include/rtc/manage_ptr.hpp @@ -0,0 +1,55 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER + +#include +#include + +namespace rtc { +template +struct manage_deleter +{ + template + void operator()(T* x) const + { + if(x != nullptr) + { + (void)f(x); + } + } +}; + +struct null_deleter +{ + template + void operator()(T*) const + { + } +}; + +template +using manage_ptr = std::unique_ptr>; + +template +struct element_type +{ + using type = typename T::element_type; +}; + +template +using remove_ptr = typename std:: + conditional_t{}, std::remove_pointer, element_type>::type; + +template +using shared = std::shared_ptr>; + +template +shared share(T p) +{ + return shared{std::move(p)}; +} + +#define RTC_MANAGE_PTR(T, F) rtc::manage_ptr, decltype(&F), &F> + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a0a2cb9b77480f7c32fb531f77a8ad049024dab2 --- /dev/null +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -0,0 +1,24 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR + +#include +#include + +namespace rtc { + +struct tmp_dir +{ + fs::path path; + tmp_dir(const std::string& prefix = ""); + + void execute(const std::string& cmd) const; + + tmp_dir(tmp_dir const&) = delete; + tmp_dir& operator=(tmp_dir const&) = delete; + + ~tmp_dir(); +}; + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8cb71b9043cb92c675ce421d668f95a8886291c2 --- /dev/null +++ b/codegen/test/rtc/src/compile_kernel.cpp 
@@ -0,0 +1,103 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace rtc { + +template +T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0) +{ + std::ifstream is(filename, std::ios::binary | std::ios::ate); + if(nbytes == 0) + { + // if there is a non-zero offset and nbytes is not set, + // calculate size of remaining bytes to read + nbytes = is.tellg(); + if(offset > nbytes) + throw std::runtime_error("offset is larger than file size"); + nbytes -= offset; + } + if(nbytes < 1) + throw std::runtime_error("Invalid size for: " + filename); + is.seekg(offset, std::ios::beg); + + T buffer(nbytes, 0); + if(not is.read(&buffer[0], nbytes)) + throw std::runtime_error("Error reading file: " + filename); + return buffer; +} + +std::vector read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0) +{ + return generic_read_file>(filename, offset, nbytes); +} + +std::string read_string(const std::string& filename) +{ + return generic_read_file(filename); +} + +void write_buffer(const std::string& filename, const char* buffer, std::size_t size) +{ + std::ofstream os(filename); + os.write(buffer, size); +} +void write_buffer(const std::string& filename, const std::vector& buffer) +{ + write_buffer(filename, buffer.data(), buffer.size()); +} +void write_string(const std::string& filename, const std::string_view& buffer) +{ + write_buffer(filename, buffer.data(), buffer.size()); +} + +std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; } +// TODO: undo after extracting the codeobj +// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; } + +kernel compile_kernel(const std::vector& srcs, compile_options options) +{ + assert(not srcs.empty()); + tmp_dir td{"compile"}; + options.flags += " -I. 
-O3"; + options.flags += " -std=c++17"; + options.flags += " --offload-arch=" + get_device_name(); + std::string out; + + for(const auto& src : srcs) + { + fs::path full_path = td.path / src.path; + fs::path parent_path = full_path.parent_path(); + fs::create_directories(parent_path); + write_string(full_path.string(), src.content); + if(src.path.extension().string() == ".cpp") + { + options.flags += " -c " + src.path.filename().string(); + if(out.empty()) + out = src.path.stem().string() + ".o"; + } + } + + options.flags += " -o " + out; + td.execute(compiler() + options.flags); + + auto out_path = td.path / out; + if(not fs::exists(out_path)) + throw std::runtime_error("Output file missing: " + out); + + auto obj = read_buffer(out_path.string()); + + std::ofstream ofh("obj.o", std::ios::binary); + for(auto i : obj) + ofh << i; + ofh.close(); + // int s = std::system(("/usr/bin/cp " + out_path.string() + " codeobj.bin").c_str()); + // assert(s == 0); + return kernel{obj.data(), options.kernel_name}; +} + +} // namespace rtc diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp new file mode 100644 index 0000000000000000000000000000000000000000..747f83e3baa240159adcf2e89847f4a1bad245a8 --- /dev/null +++ b/codegen/test/rtc/src/hip.cpp @@ -0,0 +1,106 @@ +#include +#include +#include +#include +#include + +namespace rtc { + +using hip_ptr = RTC_MANAGE_PTR(void, hipFree); + +std::string hip_error(int error) { return hipGetErrorString(static_cast(error)); } + +int get_device_id() +{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + throw std::runtime_error("No device"); + return device; +} + +std::string get_device_name() +{ + hipDeviceProp_t props{}; + auto status = hipGetDeviceProperties(&props, get_device_id()); + if(status != hipSuccess) + throw std::runtime_error("Failed to get device properties"); + return props.gcnArchName; +} + +bool is_device_ptr(const void* ptr) +{ + hipPointerAttribute_t attr; + auto status = hipPointerGetAttributes(&attr, ptr); + if(status != hipSuccess) + return false; + return attr.type == hipMemoryTypeDevice; +} + +void gpu_sync() +{ + auto status = hipDeviceSynchronize(); + if(status != hipSuccess) + throw std::runtime_error("hip device synchronization failed: " + hip_error(status)); +} + +std::size_t get_available_gpu_memory() +{ + size_t free; + size_t total; + auto status = hipMemGetInfo(&free, &total); + if(status != hipSuccess) + { + std::cerr << "Failed getting available memory: " + hip_error(status) << std::endl; + return (8ull * 1024ull * 1024ull * 1024ull); + } + return free; +} + +std::shared_ptr allocate_gpu(std::size_t sz, bool host) +{ + if(sz > get_available_gpu_memory()) + throw std::runtime_error("Memory not available to allocate buffer: " + std::to_string(sz)); + void* alloc_ptr = nullptr; + auto status = host ? 
hipHostMalloc(&alloc_ptr, sz) : hipMalloc(&alloc_ptr, sz); + if(status != hipSuccess) + { + if(host) + throw std::runtime_error("Gpu allocation failed: " + hip_error(status)); + else + return allocate_gpu(sz, true); + } + assert(alloc_ptr != nullptr); + std::shared_ptr result = share(hip_ptr{alloc_ptr}); + return result; +} + +std::shared_ptr write_to_gpu(const void* x, std::size_t sz, bool host) +{ + gpu_sync(); + auto result = allocate_gpu(sz, host); + assert(is_device_ptr(result.get())); + assert(not is_device_ptr(x)); + auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice); + if(status != hipSuccess) + throw std::runtime_error("Copy to gpu failed: " + hip_error(status)); + return result; +} + +std::shared_ptr read_from_gpu(const void* x, std::size_t sz) +{ + gpu_sync(); + std::shared_ptr result(new char[sz]); + assert(not is_device_ptr(result.get())); + if(not is_device_ptr(x)) + { + throw std::runtime_error( + "read_from_gpu() requires Src buffer to be on the GPU, Copy from gpu failed\n"); + } + auto status = hipMemcpy(result.get(), x, sz, hipMemcpyDeviceToHost); + if(status != hipSuccess) + throw std::runtime_error("Copy from gpu failed: " + hip_error(status)); // NOLINT + return std::static_pointer_cast(result); +} + +} // namespace rtc diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9fe38e84ad6624bcb82d8f3f97a0767ecd92108c --- /dev/null +++ b/codegen/test/rtc/src/kernel.cpp @@ -0,0 +1,121 @@ +#include +#include +#include +#include + +// extern declare the function since hip/hip_ext.h header is broken +extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + size_t, + hipStream_t, + void**, + void**, + hipEvent_t = nullptr, + hipEvent_t = nullptr, + uint32_t = 0); + +namespace rtc { + +std::vector pack_args(const std::vector& args) +{ + std::vector kernargs; + for(auto&& arg : args) + { + std::size_t n = arg.size; + const auto* p = static_cast(arg.data); + // Insert padding + std::size_t padding = (arg.align - (kernargs.size() % arg.align)) % arg.align; + kernargs.insert(kernargs.end(), padding, 0); + kernargs.insert(kernargs.end(), p, p + n); + } + return kernargs; +} + +using hip_module_ptr = RTC_MANAGE_PTR(hipModule_t, hipModuleUnload); + +struct kernel_impl +{ + hip_module_ptr module = nullptr; + hipFunction_t fun = nullptr; +}; + +hip_module_ptr load_module(const char* image) +{ + hipModule_t raw_m; + auto status = hipModuleLoadData(&raw_m, image); + hip_module_ptr m{raw_m}; + if(status != hipSuccess) + throw std::runtime_error("Failed to load module: " + hip_error(status)); + return m; +} + +kernel::kernel(const char* image, const std::string& name) : impl(std::make_shared()) +{ + impl->module = load_module(image); + auto status = hipModuleGetFunction(&impl->fun, impl->module.get(), name.c_str()); + if(hipSuccess != status) + throw std::runtime_error("Failed to get function: " + name + ": " + hip_error(status)); +} + +void launch_kernel(hipFunction_t fun, + hipStream_t stream, + std::size_t global, + std::size_t local, + void* kernargs, + std::size_t size) +{ + assert(global > 0); + assert(local > 0); + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kernargs, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + &size, + HIP_LAUNCH_PARAM_END}; + + auto status = hipExtModuleLaunchKernel(fun, + global, + 1, + 1, + local, + 1, + 1, + 0, + stream, + nullptr, + reinterpret_cast(&config), + nullptr, + 
nullptr); + if(status != hipSuccess) + throw std::runtime_error("Failed to launch kernel: " + hip_error(status)); +} + +void kernel::launch(hipStream_t stream, + std::size_t global, + std::size_t local, + std::vector args) const +{ + assert(impl != nullptr); + void* kernargs = args.data(); + std::size_t size = args.size() * sizeof(void*); + + launch_kernel(impl->fun, stream, global, local, kernargs, size); +} + +void kernel::launch(hipStream_t stream, + std::size_t global, + std::size_t local, + const std::vector& args) const +{ + assert(impl != nullptr); + std::vector kernargs = pack_args(args); + std::size_t size = kernargs.size(); + + launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); +} + +} // namespace rtc diff --git a/codegen/test/rtc/src/tmp_dir.cpp b/codegen/test/rtc/src/tmp_dir.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e89bc35399075d67dcbc03621c1445c6eb6f66b --- /dev/null +++ b/codegen/test/rtc/src/tmp_dir.cpp @@ -0,0 +1,48 @@ +#include +#include +#include +#include +#include + +namespace rtc { +std::string random_string(std::string::size_type length) +{ + static const std::string& chars = "0123456789" + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + std::mt19937 rg{std::random_device{}()}; + std::uniform_int_distribution pick(0, chars.length() - 1); + + std::string str(length, 0); + std::generate(str.begin(), str.end(), [&] { return chars[pick(rg)]; }); + + return str; +} + +std::string unique_string(const std::string& prefix) +{ + auto pid = getpid(); + auto tid = std::this_thread::get_id(); + auto clk = std::chrono::steady_clock::now().time_since_epoch().count(); + std::stringstream ss; + ss << std::hex << prefix << "-" << pid << "-" << tid << "-" << clk << "-" << random_string(16); + return ss.str(); +} + +tmp_dir::tmp_dir(const std::string& prefix) + : path(fs::temp_directory_path() / + unique_string(prefix.empty() ? "ck-rtc" : "ck-rtc-" + prefix)) +{ + fs::create_directories(this->path); +} + +void tmp_dir::execute(const std::string& cmd) const +{ + std::string s = "cd " + path.string() + "; " + cmd; + std::system(s.c_str()); +} + +tmp_dir::~tmp_dir() { fs::remove_all(this->path); } + +} // namespace rtc diff --git a/dev-requirements.txt b/dev-requirements.txt index 9e7b9f01e1cae3901bd0f1661f39bc383300ebd6..ca883c19e1f527e094fd8e83bffe962a8de77be5 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ -ROCmSoftwarePlatform/rocm-recipes -RadeonOpenCompute/rocm-cmake@04f694df2a8dc9d7e35fa4dee4ba5fa407ec04f8 --build -danmar/cppcheck@2.9 \ No newline at end of file +ROCm/rocm-recipes +ROCm/rocm-cmake@04f694df2a8dc9d7e35fa4dee4ba5fa407ec04f8 --build +danmar/cppcheck@2.9 diff --git a/docs/API_Reference_Guide.rst b/docs/API_Reference_Guide.rst deleted file mode 100644 index f21d43c5939ab07d3e9b6cc2c6c361a782b13ff1..0000000000000000000000000000000000000000 --- a/docs/API_Reference_Guide.rst +++ /dev/null @@ -1,52 +0,0 @@ - -******************* -API Reference Guide -******************* - -================= -Introduction -================= - -This document contains details of the APIs for the Composable Kernel (CK) library and introduces -some of the key design principles that are used to write new classes that extend CK functionality. - -================= -Using CK API -================= - -This section describes how to use the CK library API. - -================= -CK Datatypes -================= - ------------------ -DeviceMem ------------------ - -.. 
doxygenstruct:: DeviceMem - ---------------------------- -Kernels For Flashattention ---------------------------- - -The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists -the classes that are used in the CK GPU implementation of Flashattention. - -**Gridwise classes** - -.. doxygenstruct:: ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle - -**Blockwise classes** - -.. doxygenstruct:: ck::ThreadGroupTensorSliceTransfer_v4r1 - -.. doxygenstruct:: ck::BlockwiseGemmXdlops_v2 - -.. doxygenstruct:: ck::BlockwiseSoftmax - -**Threadwise classes** - -.. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic - -.. bibliography:: diff --git a/docs/Contributors_Guide.rst b/docs/Contributors_Guide.rst index b2ddff398ce8efcd0c5d8be275d9dd09315d4349..3788ba609c833084bda41819b1109f0c2b3dbc2e 100644 --- a/docs/Contributors_Guide.rst +++ b/docs/Contributors_Guide.rst @@ -1,8 +1,105 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _contributing-to: + +******************************************************************** +Contributor's guide +******************************************************************** + +This chapter explains the rules for contributing to the Composable Kernel project, and how to contribute. + +Getting started +=============== + +#. **Documentation:** Before contributing to the library, familiarize yourself with the + `Composable Kernel User Guide `_. + It provides insight into the core concepts, environment configuration, and steps to obtain or + build the library. You can also find some of this information in the + `README file `_ + on the project's GitHub page. +#. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code `_ provides a deeper understanding of the CK library and showcases its performance capabilities. + `_ + from the AMD Community portal. It offers a deeper understanding of the library's objectives and showcases its performance capabilities. +#. **General information:** For broader information about AMD products, consider exploring the + `AMD Developer Central portal `_. + +How to contribute =================== -Contributor's Guide -=================== -Pull-request guidelines -======================= +You can make an impact by reporting issues or proposing code enhancements through pull requests. + +Reporting issues +---------------- + +Use `Github issues `_ +to track public bugs and enhancement requests. + +If you encounter an issue with the library, please check if the problem has already been +reported by searching existing issues on GitHub. If your issue seems unique, please submit a new +issue. All reported issues must include: + +* A comprehensive description of the problem, including: + + * What did you observe? + * Why do you think it is a bug (if it seems like one)? + * What did you expect to happen? What would indicate the resolution of the problem? + * Are there any known workarounds? + +* Your configuration details, including: + + * Which GPU are you using? + * Which OS version are you on? + * Which ROCm version are you using? + * Are you using a Docker image? If so, which one? + +* Steps to reproduce the issue, including: + + * What actions trigger the issue? What are the reproduction steps? + + * If you build the library from scratch, what CMake command did you use? + + * How frequently does this issue happen? 
Does it reproduce every time? Or is it a sporadic issue? + +Before submitting any issue, ensure you have addressed all relevant questions from the checklist. + +Creating Pull Requests +---------------------- + +You can submit `Pull Requests (PR) on GitHub +`_. + +All contributors are required to develop their changes on a separate branch and then create a +pull request to merge their changes into the `develop` branch, which is the default +development branch in the Composable Kernel project. All external contributors must use their own +forks of the project to develop their changes. + +When submitting a Pull Request you should: + +* Describe the change providing information about the motivation for the change and a general + description of all code modifications. + +* Verify and test the change: + + * Run any relevant existing tests. + * Write new tests if added functionality is not covered by current tests. + +* Ensure your changes align with the coding style defined in the ``.clang-format`` file located in + the project's root directory. We leverage `pre-commit` to run `clang-format` automatically. We + highly recommend contributors utilize this method to maintain consistent code formatting. + Instructions on setting up `pre-commit` can be found in the project's + `README file `_ + +* Link your PR to any related issues: + + * If there is an issue that is resolved by your change, please provide a link to the issue in + the description of your pull request. + +* For larger contributions, structure your change into a sequence of smaller, focused commits, each + addressing a particular aspect or fix. + +Following the above guidelines ensures a seamless review process and faster assistance from our +end. -[TODO] +Thank you for your commitment to enhancing the Composable Kernel project! diff --git a/docs/conceptual/what-is-ck.rst b/docs/conceptual/what-is-ck.rst new file mode 100644 index 0000000000000000000000000000000000000000..36785fc6ca68135bf32bccf363f1ce41cfaa17ad --- /dev/null +++ b/docs/conceptual/what-is-ck.rst @@ -0,0 +1,41 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _what-is-ck: + +******************************************************************** +What is the Composable Kernel library +******************************************************************** + + +Methodology +=========== + +The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. + +CK utilizes two concepts to achieve performance portability and code maintainability: + +* A tile-based programming model +* Algorithm complexity reduction for complex ML operators using an innovative technique called + "Tensor Coordinate Transformation". + +.. image:: ../data/ck_component.png + :alt: CK Components + + +Code Structure +============== + +The CK library is structured into 4 layers: + +* "Templated Tile Operators" layer +* "Templated Kernel and Invoker" layer +* "Instantiated Kernel and Invoker" layer +* "Client API" layer + +It also includes a simple wrapper component used to perform tensor transform operations more easily and with fewer lines of code. + +.. 
image:: ../data/ck_layer.png + :alt: CK Layers + \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 0de590da1a53daba8dc0e379e68abbae250e0667..e8617a09ef9168fedf9214bf7094744fb8648a8e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,23 +4,34 @@ # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -import subprocess +import re from rocm_docs import ROCmDocs +html_theme_options = {"flavor": "list"} -name = "Composable Kernel" -get_version = r'sed -n -e "s/^rocm_setup_version(.* \([0-9\.]\{1,\}\).*/\1/p" ../CMakeLists.txt' -version = subprocess.getoutput(get_version) -if len(version) > 0: - name = f"{name} {version}" +with open('../CMakeLists.txt', encoding='utf-8') as f: + match = re.search(r'.*set\(version ([0-9.]+)[^0-9.]+', f.read()) + if not match: + raise ValueError("VERSION not found!") + version_number = match[1] +left_nav_title = f"Composable Kernel {version_number} Documentation" + +# for PDF output on Read the Docs +project = "Composable Kernel Documentation" +author = "Advanced Micro Devices, Inc." +copyright = "Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved." +version = version_number +release = version_number external_toc_path = "./sphinx/_toc.yml" -docs_core = ROCmDocs(f"{name} Documentation") -docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/docBin/xml") +docs_core = ROCmDocs(left_nav_title) +docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") docs_core.setup() +external_projects_current_project = "composable_kernel" + mathjax3_config = { 'tex': { 'macros': { @@ -34,3 +45,5 @@ for sphinx_var in ROCmDocs.SPHINX_VARS: extensions += ['sphinxcontrib.bibtex'] bibtex_bibfiles = ['refs.bib'] + +cpp_id_attributes = ["__global__", "__device__", "__host__"] diff --git a/docs/dockerhub.rst b/docs/dockerhub.rst deleted file mode 100644 index 66ec91096e01b9ebfbe78559b9a168711a7b4206..0000000000000000000000000000000000000000 --- a/docs/dockerhub.rst +++ /dev/null @@ -1,99 +0,0 @@ -=================== -CK Docker Hub -=================== - -------------------------------------- -Why do I need this? -------------------------------------- - -To make our lives easier and bring Composable Kernel dependencies together, we recommend using -docker images that can be found on `Docker Hub `_. - -------------------------------------- -So what is Composable Kernel? -------------------------------------- - -Composable Kernel (CK) library aims to provide a programming model for writing performance critical -kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, -through general purpose kernel languages, like HIP C++. - -To get the CK library:: - - git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git - - -run a docker container:: - - docker run \ - -it \ - --privileged \ - --group-add sudo \ - -w /root/workspace \ - -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.6 \ - /bin/bash - -and build the CK:: - - mkdir build && cd build - # Need to specify target ID, example below is for gfx908 and gfx90a - cmake \ - -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_CXX_FLAGS="-O3" \ - -D CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx908;gfx90a" \ - .. 
- -and:: - - make -j examples tests - -To run all the test cases including tests and examples run:: - - make test - -We can also run specific examples or tests like:: - - ./bin/example_gemm_xdl_fp16 - ./bin/test_gemm_fp16 - -For more details visit `CK github repository `_, -`CK examples `_, -`even more CK examples `_. - -------------------------------------- -And what is inside? -------------------------------------- - -The docker images have everything you need for running CK including: - -* `ROCm `_ -* `CMake `_ -* `Compiler `_ - -------------------------------------- -Which image is right for me? -------------------------------------- - -Let's take a look at the image naming, for example ``ck_ub20.04_rocm5.6``. The image specs are: - -* ``ck`` - made for running Composable Kernel; -* ``ub20.04`` - based on Ubuntu 20.04; -* ``rocm5.6`` - ROCm platform version 5.6. - -So just pick the right image for your project dependencies and you're all set. - -------------------------------------- -DIY starts here -------------------------------------- - -If you need to customize a docker image or just can't stop tinkering, feel free to adjust the -`Dockerfile `_ -for your needs. - -------------------------------------- -License -------------------------------------- - -CK is released under the MIT `license `_. diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index 1084f94c81bbef88a3373d6e9e5d24b89e7b5353..fac9e138e1bc59a8ea37afceebbdedb9ff03e084 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -58,7 +58,7 @@ PROJECT_LOGO = # entered, it will be relative to the location where doxygen was started. If # left blank the current directory will be used. -OUTPUT_DIRECTORY = docBin +OUTPUT_DIRECTORY = . # If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- # directories (in 2 levels) under the output directory of each output format and @@ -778,7 +778,9 @@ WARN_LOGFILE = INPUT = ../../include/ck/tensor_operation/gpu/grid \ ../../include/ck/tensor_operation/gpu/block \ ../../include/ck/tensor_operation/gpu/thread \ - ../../library/include/ck/library/utility + ../../library/include/ck/library/utility \ + ../../include/ck/wrapper + # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/docs/index.rst b/docs/index.rst index 51c0c862ae3e3e8b314aa37bd63b6d762b895607..30ef672f84f0d7efbe1900a418b4944552353a47 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,55 +1,38 @@ -============================ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _composable-kernel: + +******************************************************************** Composable Kernel User Guide -============================ +******************************************************************** ------------- -Introduction ------------- +The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. This document contains instructions for installing, using, and contributing to the Composable Kernel project. To learn more see :ref:`what-is-ck`. -This document contains instructions for installing, using, and contributing to Composable Kernel (CK). 
+The CK documentation is structured as follows: ------------ -Methodology ------------ +.. grid:: 2 + :gutter: 3 -Composable Kernel (CK) library aims to provide a programming model for writing performance critical -kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, -through general purpose kernel languages, like HIP C++. + .. grid-item-card:: Installation -CK utilizes two concepts to achieve performance portability and code maintainability: + * :ref:`docker-hub` -* A tile-based programming model -* Algorithm complexity reduction for complex ML operators, using innovative technique we call - "Tensor Coordinate Transformation". + .. grid-item-card:: Conceptual -.. image:: data/ck_component.png - :alt: CK Components + * :ref:`what-is-ck` --------------- -Code Structure --------------- + .. grid-item-card:: API reference -Current CK library are structured into 4 layers: + * :ref:`supported-primitives` + * :ref:`api-reference` + * :ref:`wrapper` -* "Templated Tile Operators" layer -* "Templated Kernel and Invoker" layer -* "Instantiated Kernel and Invoker" layer -* "Client API" layer + .. grid-item-card:: Tutorial -.. image:: data/ck_layer.png - :alt: CK Layers - -Documentation Roadmap -^^^^^^^^^^^^^^^^^^^^^ -The following is a list of CK documents in the suggested reading order: + * :ref:`hello-world` -.. toctree:: - :maxdepth: 5 - :caption: Contents: - :numbered: +To contribute to the documentation refer to `Contributing to ROCm `_. - tutorial_hello_world - dockerhub - Supported_Primitives_Guide - API_Reference_Guide - Contributors_Guide +You can find licensing information on the `Licensing `_ page. diff --git a/docs/install/dockerhub.rst b/docs/install/dockerhub.rst new file mode 100644 index 0000000000000000000000000000000000000000..87eb5a4f81393a184cc699c3659683ae0286cdcc --- /dev/null +++ b/docs/install/dockerhub.rst @@ -0,0 +1,101 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _docker-hub: + +******************************************************************** +CK Docker Hub +******************************************************************** + +Why do I need this? +=================== + +To make things simpler, and bring Composable Kernel and its dependencies together, +docker images can be found on `Docker Hub `_. Docker images provide a complete image of the OS, the Composable Kernel library, and its dependencies in a single downloadable file. + +Refer to `Docker Overview `_ for more information on Docker images and containers. + +Which image is right for me? +============================ + +The image naming includes information related to the docker image. +For example ``ck_ub20.04_rocm6.0`` indicates the following: + +* ``ck`` - made for running Composable Kernel; +* ``ub20.04`` - based on Ubuntu 20.04; +* ``rocm6.0`` - ROCm platform version 6.0. + +Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. Use the ``docker pull`` command to download the file:: + + docker pull rocm/composable_kernel:ck_ub20.04_rocm6.0 + + +What is inside the image? 
+------------------------- + +The docker images have everything you need for running CK including: + +* `ROCm `_ +* `CMake `_ +* `Compiler `_ +* `Composable Kernel library `_ + +Running the docker container +============================ + +After downloading the docker image, you can start the container using one of a number of commands. Start with the ``docker run`` command as shown below:: + + docker run \ + -it \ + --privileged \ + --group-add sudo \ + -w /root/workspace \ + -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ + rocm/composable_kernel:ck_ub20.04_rocm6.0 \ + /bin/bash + +After starting the bash shell, the docker container current folder is `~/workspace`. The library path is ``~/workspace/composable_kernel``. Navigate to the library to begin the tutorial as explained in :ref:`hello-world`: + +.. note:: + + If your current folder is different from `${HOME}`, adjust the line ``-v ${HOME}:/root/workspace`` in the ``docker run`` command to fit your folder structure. + +Stop and restart the docker image +================================= + +After finishing the tutorial, or just when you have completed your work session, you can close the docker container, or stop the docker container to restart it at another time. Closing the docker container means that it is still in the active state, and can be resumed from where you left it. Stopping the container closes it, and returns the image to its initial state. + +Use the ``Ctrl-D`` option to exit the container, while leaving it active, so you can return to the container in its current state to resume the tutorial, or pickup your project where you left off. + +To restart the active container use the ``docker exec`` command to specify the container name and options as follows:: + + docker exec -it bash + +Where: + +* `exec` is the docker command +* `-it` is the interactive option for `exec` +* `` specifies an active container on the system +* `bash` specifies the command to run in the interactive shell + +.. note:: + + You can use the ``docker container ls`` command to list the active containers on the system. + +To start a container from the image, use the ``docker start`` command:: + + docker start + +Then use the docker exec command as shown above to start the bash shell. + +Use the ``docker stop`` command to stop the container and restore the image to its initial state:: + + docker stop + +Editing the docker image +======================= + +If you want to customize the docker image, edit the +`Dockerfile `_ +from the GitHub repository to suit your needs. diff --git a/docs/license.rst b/docs/license.rst index ddb544496e41b92dcaadb6ff816946e46ca16da8..1e5389ccc10b944fd61fc546b41d8b27af22a5f0 100644 --- a/docs/license.rst +++ b/docs/license.rst @@ -1,6 +1,11 @@ -======= +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _license: + +******************************************************************** License -======= +******************************************************************** -.. include:: ../LICENSE - :literal: +.. include:: ../LICENSE \ No newline at end of file diff --git a/docs/reference/API_Reference_Guide.rst b/docs/reference/API_Reference_Guide.rst new file mode 100644 index 0000000000000000000000000000000000000000..0d2d41c1ebca11789ed21176f5ebdb9b4b038ebc --- /dev/null +++ b/docs/reference/API_Reference_Guide.rst @@ -0,0 +1,48 @@ +.. 
meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _api-reference: + +******************************************************************** +API reference guide +******************************************************************** + + +This document contains details of the APIs for the Composable Kernel (CK) library and introduces +some of the key design principles that are used to write new classes that extend CK functionality. + +================= +CK Datatypes +================= + +----------------- +DeviceMem +----------------- + +.. doxygenstruct:: DeviceMem + +--------------------------- +Kernels For Flashattention +--------------------------- + +The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists +the classes that are used in the CK GPU implementation of Flashattention. + +**Gridwise classes** + +.. doxygenstruct:: ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle + +**Blockwise classes** + +.. doxygenstruct:: ck::ThreadGroupTensorSliceTransfer_v4r1 + +.. doxygenstruct:: ck::BlockwiseGemmXdlops_v2 + +.. doxygenstruct:: ck::BlockwiseSoftmax + +**Threadwise classes** + +.. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic + +.. bibliography:: diff --git a/docs/Supported_Primitives_Guide.rst b/docs/reference/Supported_Primitives_Guide.rst similarity index 79% rename from docs/Supported_Primitives_Guide.rst rename to docs/reference/Supported_Primitives_Guide.rst index 3462283d90d62f1a181a1cc26d04a36f9d6f4fff..e24acf56562a7d6f18a525dad6f63f2ef63eb9c0 100644 --- a/docs/Supported_Primitives_Guide.rst +++ b/docs/reference/Supported_Primitives_Guide.rst @@ -1,16 +1,20 @@ -========================== +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _supported-primitives: + +******************************************************************** Supported Primitives Guide -========================== +******************************************************************** -This document contains details of supported primitives in Composable Kernel (CK). In contrast to the -API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins -the algorithms implemented in CK. +This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK. ------------ Softmax ------------ -For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the +For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` you can decompose the softmax of concatenated :math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, .. math:: @@ -27,7 +31,7 @@ where :math:`f(x^{(j)}) = \exp( x^{(j)} - m(x^{(j)}) )` is of size :math:`B` and :math:`z(x^{(j)}) = f(x_1^{(j)})+ \ldots+ f(x_B^{(j)})` is a scalar. For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size -:math:`B_r \times B_c` we can compute the row-wise softmax as follows. +:math:`B_r \times B_c` you can compute the row-wise softmax as follows. 
For :math:`j` from :math:`1` to :math:`T_c`, and :math:`i` from :math:`1` to :math:`T_r` calculate, diff --git a/docs/reference/wrapper.rst b/docs/reference/wrapper.rst new file mode 100644 index 0000000000000000000000000000000000000000..190fbcd445d8ce5ea3f261c862c0706dd578f092 --- /dev/null +++ b/docs/reference/wrapper.rst @@ -0,0 +1,94 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _wrapper: + +******************************************************************** +Wrapper +******************************************************************** + +------------------------------------- +Description +------------------------------------- + + +The CK library provides a lightweight wrapper for more complex operations implemented in +the library. + +Example: + +.. code-block:: c + + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + + std::array data; + auto tensor = ck::wrapper::make_tensor(&data[0], layout); + + for(ck::index_t w = 0; w < size(tensor); w++) { + tensor(w) = w; + } + + // slice() == slice(0, -1) (whole dimension) + auto tensor_slice = tensor(ck::wrapper::slice(1, 3), ck::make_tuple(ck::wrapper::slice(), ck::wrapper::slice())); + std::cout << "dims:2,(2,4) strides:2,(1,8)" << std::endl; + for(ck::index_t h = 0; h < ck::wrapper::size<0>(tensor_slice); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(tensor_slice); w++) + { + std::cout << tensor_slice(h, w) << " "; + } + std::cout << std::endl; + } + +Output:: + + dims:2,(2,4) strides:2,(1,8) + 1 5 9 13 17 21 25 29 + 2 6 10 14 18 22 26 30 + + +Tutorials: + +* `GEMM tutorial `_ + +Advanced examples: + +* `Image to column `_ +* `Basic gemm `_ +* `Optimized gemm `_ + +------------------------------------- +Layout +------------------------------------- + +.. doxygenstruct:: Layout + +------------------------------------- +Layout helpers +------------------------------------- + +.. doxygenfile:: include/ck/wrapper/utils/layout_utils.hpp + +------------------------------------- +Tensor +------------------------------------- + +.. doxygenstruct:: Tensor + +------------------------------------- +Tensor helpers +------------------------------------- + +.. doxygenfile:: include/ck/wrapper/utils/tensor_utils.hpp + +.. doxygenfile:: include/ck/wrapper/utils/tensor_partition.hpp + +------------------------------------- +Operations +------------------------------------- + +.. doxygenfile:: include/ck/wrapper/operations/copy.hpp +.. doxygenfile:: include/ck/wrapper/operations/gemm.hpp diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 83dd1e7b1a5d759143f2db13a1e296f5e66da96d..533b81cd39355b0c438a3655eebab58aa9513cbd 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -1,10 +1,36 @@ -# Anywhere {branch} is used, the branch name will be substituted. -# These comments will also be removed. defaults: numbered: False - maxdepth: 6 root: index subtrees: - - caption: About - entries: - - file: license + +- caption: Conceptual + entries: + - file: conceptual/what-is-ck.rst + title: What is Composable Kernel? 
+ +- caption: Install + entries: + - file: install/dockerhub.rst + title: Docker Hub + +- caption: CK API Reference + entries: + - file: reference/Supported_Primitives_Guide.rst + title: Supported Primitives + - file: reference/API_Reference_Guide.rst + title: API Reference + - file: reference/wrapper.rst + title: Wrapper + +- caption: Tutorial + entries: + - file: tutorial/tutorial_hello_world.rst + title: Hello World Tutorial + +- caption: About + entries: + - file: Contributors_Guide.rst + title: Contributing to CK + - file: license.rst + title: License + \ No newline at end of file diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index a12a00a2b21ec12d8d6f41d0838309350d2eefd1..3a2e266ef5ef48470ec10df18eb11cabec34104d 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core>=0.20.0 -sphinxcontrib-bibtex==2.5.0 +rocm-docs-core==1.8.5 +sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index d5f67eeb585b3b457ee49bf5b1b46468f76e742a..b65d2391f65da7295f84d801b0aae082c23769b1 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -1,36 +1,36 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile requirements.in # -accessible-pygments==0.0.3 +accessible-pygments==0.0.5 # via pydata-sphinx-theme -alabaster==0.7.13 +alabaster==0.7.16 # via sphinx -babel==2.12.1 +babel==2.15.0 # via # pydata-sphinx-theme # sphinx -beautifulsoup4==4.11.2 +beautifulsoup4==4.12.3 # via pydata-sphinx-theme -breathe==4.34.0 +breathe==4.35.0 # via rocm-docs-core -certifi==2022.12.7 +certifi==2024.7.4 # via requests -cffi==1.15.1 +cffi==1.16.0 # via # cryptography # pynacl -charset-normalizer==3.1.0 +charset-normalizer==3.3.2 # via requests -click==8.1.3 +click==8.1.7 # via sphinx-external-toc -cryptography==40.0.2 +cryptography==43.0.0 # via pyjwt -deprecated==1.2.13 +deprecated==1.2.14 # via pygithub -docutils==0.16 +docutils==0.21.2 # via # breathe # myst-parser @@ -38,35 +38,35 @@ docutils==0.16 # pydata-sphinx-theme # sphinx # sphinxcontrib-bibtex -fastjsonschema==2.18.0 +fastjsonschema==2.20.0 # via rocm-docs-core -gitdb==4.0.10 +gitdb==4.0.11 # via gitpython -gitpython==3.1.31 +gitpython==3.1.43 # via rocm-docs-core -idna==3.4 +idna==3.7 # via requests imagesize==1.4.1 # via sphinx -jinja2==3.1.2 +jinja2==3.1.4 # via # myst-parser # sphinx -latexcodec==2.0.1 +latexcodec==3.0.0 # via pybtex -markdown-it-py==2.2.0 +markdown-it-py==3.0.0 # via # mdit-py-plugins # myst-parser -markupsafe==2.1.2 +markupsafe==2.1.5 # via jinja2 -mdit-py-plugins==0.3.5 +mdit-py-plugins==0.4.1 # via myst-parser mdurl==0.1.2 # via markdown-it-py -myst-parser==1.0.0 +myst-parser==3.0.1 # via rocm-docs-core -packaging==23.0 +packaging==24.1 # via # pydata-sphinx-theme # sphinx @@ -74,48 +74,46 @@ pybtex==0.24.0 # via # pybtex-docutils # sphinxcontrib-bibtex -pybtex-docutils==1.0.2 +pybtex-docutils==1.0.3 # via sphinxcontrib-bibtex -pycparser==2.21 +pycparser==2.22 # via cffi -pydata-sphinx-theme==0.13.3 +pydata-sphinx-theme==0.15.4 # via # rocm-docs-core # sphinx-book-theme -pygithub==1.58.2 +pygithub==2.3.0 # via rocm-docs-core -pygments==2.14.0 +pygments==2.18.0 # via # accessible-pygments # pydata-sphinx-theme # sphinx -pyjwt[crypto]==2.6.0 +pyjwt[crypto]==2.8.0 # via pygithub pynacl==1.5.0 # via pygithub -pyyaml==6.0 +pyyaml==6.0.1 # via # myst-parser # pybtex # rocm-docs-core 
# sphinx-external-toc -requests==2.28.2 +requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core>=0.20.0 +rocm-docs-core==1.8.5 # via -r requirements.in six==1.16.0 - # via - # latexcodec - # pybtex -smmap==5.0.0 + # via pybtex +smmap==5.0.1 # via gitdb snowballstemmer==2.2.0 # via sphinx -soupsieve==2.4 +soupsieve==2.5 # via beautifulsoup4 -sphinx==5.3.0 +sphinx==7.4.7 # via # breathe # myst-parser @@ -127,33 +125,39 @@ sphinx==5.3.0 # sphinx-external-toc # sphinx-notfound-page # sphinxcontrib-bibtex -sphinx-book-theme==1.0.1 +sphinx-book-theme==1.1.3 # via rocm-docs-core -sphinx-copybutton==0.5.1 +sphinx-copybutton==0.5.2 # via rocm-docs-core -sphinx-design==0.3.0 +sphinx-design==0.6.0 # via rocm-docs-core -sphinx-external-toc==0.3.1 +sphinx-external-toc==1.0.1 # via rocm-docs-core -sphinx-notfound-page==0.8.3 +sphinx-notfound-page==1.0.3 # via rocm-docs-core -sphinxcontrib-applehelp==1.0.4 +sphinxcontrib-applehelp==2.0.0 # via sphinx -sphinxcontrib-bibtex==2.5.0 +sphinxcontrib-bibtex==2.6.3 # via -r requirements.in -sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-devhelp==2.0.0 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 +sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-qthelp==2.0.0 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib-serializinghtml==2.0.0 # via sphinx -typing-extensions==4.5.0 - # via pydata-sphinx-theme -urllib3==1.26.15 - # via requests -wrapt==1.15.0 +tomli==2.0.1 + # via sphinx +typing-extensions==4.12.2 + # via + # pydata-sphinx-theme + # pygithub +urllib3==2.2.2 + # via + # pygithub + # requests +wrapt==1.16.0 # via deprecated diff --git a/docs/tutorial/tutorial_hello_world.rst b/docs/tutorial/tutorial_hello_world.rst new file mode 100644 index 0000000000000000000000000000000000000000..c31460785bd127b92a4b2f53551541b041919102 --- /dev/null +++ b/docs/tutorial/tutorial_hello_world.rst @@ -0,0 +1,165 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _hello-world: + +******************************************************************** +Hello World Tutorial +******************************************************************** + +This tutorial is for engineers dealing with artificial intelligence and machine learning who +would like to optimize pipelines and improve performance using the Composable +Kernel (CK) library. This tutorial provides an introduction to the CK library. You will build the library and run some examples using a "Hello World" example. + +Description +=========== + +Modern AI technology solves more and more problems in a variety of fields, but crafting fast and +efficient workflows is still challenging. CK can make the AI workflow fast +and efficient. CK is a collection of optimized AI operator kernels with tools to create +new kernels. The library has components required for modern neural network architectures +including matrix multiplication, convolution, contraction, reduction, attention modules, a variety of activation functions, and fused operators. + +CK library acceleration features are based on: + +* Layered structure +* Tile-based computation model +* Tensor coordinate transformation +* Hardware acceleration use +* Support of low precision data types including fp16, bf16, int8 and int4 + +If you need more technical details and benchmarking results read the following +`blog post `_. 
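+
+As a quick illustration of the low-precision support listed above, the example sources in this repository refer to
+these types through aliases such as the following (a sketch of the naming pattern used in ``example/01_gemm``;
+note that the ``int4`` examples are only built when ``USE_BITINT_EXTENSION_INT4`` is enabled):
+
+.. code-block:: c
+
+   using F16  = ck::half_t;   // fp16
+   using BF16 = ck::bhalf_t;  // bf16
+   using I8   = int8_t;       // int8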
+ +To download the library visit the `composable_kernel repository `_. + +Hardware targets +================ + +CK library fully supports `gfx908` and `gfx90a` GPU architectures, while only some operators are +supported for `gfx1030` devices. Check your hardware to determine the target GPU architecture. + +========== ========= +GPU Target AMD GPU +========== ========= +gfx908 Radeon Instinct MI100 +gfx90a Radeon Instinct MI210, MI250, MI250X +gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT +========== ========= + +There are also `cloud options `_ you can find if +you don't have an AMD GPU at hand. + +Build the library +================= + +This tutorial is based on the use of docker images as explained in :ref:`docker-hub`. Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. + +.. note:: + + You can also `install ROCm `_ on your system, clone the `Composable Kernel repository `_ on GitHub, and use that to build and run the examples using the commands described below. + +Both the docker container and GitHub repository include the Composable Kernel library. Navigate to the library:: + + cd composable_kernel/ + +Create and change to a ``build`` directory:: + + mkdir build && cd build + +The previous section discussed supported GPU architecture. Once you decide which hardware targets are needed, run CMake using the ``GPU_TARGETS`` flag:: + + cmake \ + -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -D CMAKE_CXX_FLAGS="-O3" \ + -D CMAKE_BUILD_TYPE=Release \ + -D BUILD_DEV=OFF \ + -D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. + +If everything goes well the CMake command will return:: + + -- Configuring done + -- Generating done + -- Build files have been written to: "/root/workspace/composable_kernel/build" + +Finally, you can build examples and tests:: + + make -j examples tests + +When complete you should see:: + + Scanning dependencies of target tests + [100%] Built target tests + +Run examples and tests +====================== + +Examples are listed as test cases as well, so you can run all examples and tests with:: + + ctest + +You can check the list of all tests by running:: + + ctest -N + +You can also run examples separately as shown in the following example execution:: + + ./bin/example_gemm_xdl_fp16 1 1 1 + +The arguments ``1 1 1`` mean that you want to run this example in the mode: verify results with CPU, initialize matrices with integers, and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. 
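+
+For example, based on the argument descriptions above, you could skip the CPU verification or switch to decimal
+initialization (a sketch; the accepted values for each argument are listed in the example's help output)::
+
+   # no verification, decimal initialization, time the kernel
+   ./bin/example_gemm_xdl_fp16 0 2 1
+
+   # verify with the CPU, integer initialization, skip benchmarking
+   ./bin/example_gemm_xdl_fp16 1 1 0
+
+The sample outputs below correspond to the original ``1 1 1`` invocation.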
+ +If you have a device based on `gfx908` or `gfx90a` architecture, and if the example runs as expected, you should see something like:: + + a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} + b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} + c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} + Perf: 1.08153 ms, 119.136 TFlops, 89.1972 GB/s, DeviceGemm_Xdl_CShuffle LoopScheduler: Interwave, PipelineVersion: v1 + +However, running it on a `gfx1030` device should result in the following:: + + a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} + b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} + c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} + DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem + +Don't worry, some operators are supported on `gfx1030` architecture, so you can run a +separate example like:: + + ./bin/example_gemm_dl_fp16 1 1 1 + +and it should return something like:: + + a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096} + b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} + c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} + arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} + arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} + arg.c_grid_desc_m_n_{ 3840, 4096} + launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} + Warm up 1 time + Start running 10 times... + Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> + +.. note:: + + A new CMake flag ``DL_KERNELS`` has been added to the latest versions of CK. If you do not see the above results when running ``example_gemm_dl_fp16``, you might need to add ``-D DL_KERNELS=ON`` to your CMake command to build the operators supported on the `gfx1030` architecture. + +You can also run a separate test:: + + ctest -R test_gemm_fp16 + +If everything goes well you should see something like:: + + Start 121: test_gemm_fp16 + 1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec + + 100% tests passed, 0 tests failed out of 1 + +Summary +======= + +In this tutorial you took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. In the next tutorial you will run kernels with different configurations to find out the best one for your hardware and task. + +P.S.: If you are running on a cloud instance, don't forget to switch off the cloud instance. diff --git a/docs/tutorial_hello_world.rst b/docs/tutorial_hello_world.rst deleted file mode 100644 index bfb197e085a2a986ae50aea6b6eeebd8d203eab1..0000000000000000000000000000000000000000 --- a/docs/tutorial_hello_world.rst +++ /dev/null @@ -1,210 +0,0 @@ -=============== -CK Hello world -=============== - -------------------------------------- -Motivation -------------------------------------- - -This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who -would like to optimize their pipelines and squeeze every performance drop by adding Composable -Kernel (CK) library to their projects. We would like to make the CK library approachable so -the tutorial is not based on the latest release and doesn't have all the bleeding edge features, -but it will be reproducible now and forever. - -During this tutorial we will have an introduction to the CK library, we will build it and run some -examples and tests, so to say we will run a "Hello world" example. 
In future tutorials we will go -in depth and breadth and get familiar with other tools and ways to integrate CK into your project. - -------------------------------------- -Description -------------------------------------- - -Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and -efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast -and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create -new ones. The library has components required for majority of modern neural networks architectures -including matrix multiplication, convolution, contraction, reduction, attention modules, variety of -activation functions, fused operators and many more. - -So how do we (almost) reach the speed of light? CK acceleration abilities are based on: - -* Layered structure. -* Tile-based computation model. -* Tensor coordinate transformation. -* Hardware acceleration use. -* Support of low precision data types including fp16, bf16, int8 and int4. - -If you are excited and need more technical details and benchmarking results - read this awesome -`blog post `_. - -For more details visit our `github repository `_. - -------------------------------------- -Hardware targets -------------------------------------- - -CK library fully supports `gfx908` and `gfx90a` GPU architectures and only some operators are -supported for `gfx1030`. Let's check the hardware you have at hand and decide on the target -GPU architecture. - -========== ========= -GPU Target AMD GPU -========== ========= -gfx908 Radeon Instinct MI100 -gfx90a Radeon Instinct MI210, MI250, MI250X -gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT -========== ========= - -There are also `cloud options `_ you can find if -you don't have an AMD GPU at hand. - -------------------------------------- -Build the library -------------------------------------- - -First let's clone the library and rebase to the tested version:: - - git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git - cd composable_kernel/ - git checkout tutorial_hello_world - -To make our lives easier we prepared -`docker images `_ with all the necessary -dependencies. Pick the right image and create a container. In this tutorial we use -``rocm/composable_kernel:ck_ub20.04_rocm5.6`` image, it is based on Ubuntu 20.04 and -ROCm v5.6. - -If your current folder is ``${HOME}``, start the docker container with:: - - docker run \ - -it \ - --privileged \ - --group-add sudo \ - -w /root/workspace \ - -v ${HOME}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.6 \ - /bin/bash - -If your current folder is different from ``${HOME}``, adjust the line ``-v ${HOME}:/root/workspace`` -to fit your folder structure. - -Inside the docker container current folder is ``~/workspace``, library path is -``~/workspace/composable_kernel``, navigate to the library:: - - cd composable_kernel/ - -Create and go to the ``build`` directory:: - - mkdir build && cd build - -In the previous section we talked about target GPU architecture. Once you decide which one is right -for you, run CMake using the right ``GPU_TARGETS`` flag:: - - cmake \ - -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_CXX_FLAGS="-O3" \ - -D CMAKE_BUILD_TYPE=Release \ - -D BUILD_DEV=OFF \ - -D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. 
- -If everything went well the CMake run will end up with:: - - -- Configuring done - -- Generating done - -- Build files have been written to: "/root/workspace/composable_kernel/build" - -Finally, we can build examples and tests:: - - make -j examples tests - -If everything is smooth, you'll see:: - - Scanning dependencies of target tests - [100%] Built target tests - ---------------------------- -Run examples and tests ---------------------------- - -Examples are listed as test cases as well, so we can run all examples and tests with:: - - ctest - -You can check the list of all tests by running:: - - ctest -N - -We can also run them separately, here is a separate example execution:: - - ./bin/example_gemm_xdl_fp16 1 1 1 - -The arguments ``1 1 1`` mean that we want to run this example in the mode: verify results with CPU, -initialize matrices with integers and benchmark the kernel execution. You can play around with -these parameters and see how output and execution results change. - -If everything goes well and you have a device based on `gfx908` or `gfx90a` architecture you should see -something like:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} - Warm up 1 time - Start running 10 times... - Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 - -Meanwhile, running it on a `gfx1030` device should result in:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem - -But don't panic, some of the operators are supported on `gfx1030` architecture, so you can run a -separate example like:: - - ./bin/example_gemm_dl_fp16 1 1 1 - -and it should result in something nice similar to:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096} - b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} - arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} - arg.c_grid_desc_m_n_{ 3840, 4096} - launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} - Warm up 1 time - Start running 10 times... - Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> - -.. note:: - - There was a new CMake flag ``DL_KERNELS`` added in the latest versions of CK. If you use one of - the newest versions of the library and do not see the above results when running - ``example_gemm_dl_fp16``, it might be necessary to add ``-D DL_KERNELS=ON`` to your CMake command - in order to build the operators supported on the `gfx1030` architecture. - -We can also run a separate test:: - - ctest -R test_gemm_fp16 - -If everything goes well you should see something like:: - - Start 121: test_gemm_fp16 - 1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec - - 100% tests passed, 0 tests failed out of 1 - ------------ -Summary ------------ - -In this tutorial we took the first look at the Composable Kernel library, built it on your system -and ran some examples and tests. 
Stay tuned, in the next tutorial we will run kernels with different -configs to find out the best one for your hardware and task. - -P.S.: Don't forget to switch off the cloud instance if you have launched one, you can find better -ways to spend your money for sure! diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index a5933262a588df5f507b8b006460e6f68a76025e..52c8ab5806454c787bb2e4648785aecdda245274 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,71 +1,89 @@ -if(DL_KERNELS) - add_custom_target(example_gemm_dl) - - add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_fp32) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_fp16) - add_example_executable(example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_dpp8_fp16) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_int8) - endif() - - if(USE_BITINT_EXTENSION_INT4) +add_custom_target(example_gemm_dl) + +add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) +add_example_dependencies(example_gemm_dl example_gemm_dl_fp32) + +add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) +add_example_dependencies(example_gemm_dl example_gemm_dl_fp16) + +add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) + +add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) +add_example_dependencies(example_gemm_dl example_gemm_dl_int8) +if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_int4) - endif(USE_BITINT_EXTENSION_INT4) -endif() + add_example_dependencies(example_gemm_dl example_gemm_dl_int4) +endif(USE_BITINT_EXTENSION_INT4) add_custom_target(example_gemm_xdl) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) - add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) - add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) - add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) - - if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_custom_target(example_gemm_wmma) - add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) - endif() - -endif() - -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -endif() - -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_int8) -endif() +add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) + +add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) + +add_example_executable(example_gemm_xdl_fp16_streamk_v3 
gemm_xdl_fp16_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3) +add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) +add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) +add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) +add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) + +add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) + +add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) + +add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) + +add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) + +add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8) if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_int4) + add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) -endif() +add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) - if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_f8) - endif() -endif() +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32) + + add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16) + set(target 1) + endif() +endforeach() + +add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) + +add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) + +add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) -add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) -add_dependencies(example_gemm_xdl 
example_gemm_xdl_fp16_f8) +add_custom_target(example_gemm_wmma) +add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) +add_example_executable(example_gemm_wmma_bf16 gemm_wmma_bf16.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16) +add_example_executable(example_gemm_wmma_int8 gemm_wmma_int8.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_int8) diff --git a/example/01_gemm/README.md b/example/01_gemm/README.md index 226783b03b0b58cfe4de15a189f07a8324b3bc1d..5edec1f04378963686ab8915b1b3ede93dc76081 100644 --- a/example/01_gemm/README.md +++ b/example/01_gemm/README.md @@ -8,16 +8,20 @@ ./bin/example_gemm_xdl 0 1 5 ``` -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` +# Instructions for ```example_gemm_xdl_fp16_streamk_v3``` + +## Run ```example_gemm_xdl_fp16_streamk_v3``` +```bash +arg1: verification (0=no, 1=yes) +arg2: initialization (0=no init, 1=integer value, 2=decimal value) +arg3: time kernel (0=no, 1=yes) +arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC +arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK) +arg11: Grid_size(-1 for max occupancy) +bin/example_gemm_xdl_fp16_streamk_v3 1 2 1 3840 4096 4096 4096 4096 4096 1 -1 a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s +problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:4032, NP:4096, KRead:4096, KP:4096, AK0:512, BK0:2048, MBlock: 18, NBlock: 16, Stream-K Selection:1, Grid size:-1} +Perf: 0.292022 ms, 441.23 TFlops, 330.348 GB/s, DeviceGemmXdlUniversal BlkSize: 256, BlkTile: 224x256x64, WaveTile: 16x16, WaveMap: 7x8, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2 ``` diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 7fd15b2833d2d7522fa3d87029493415fd596fad..6e1c9f2a0da41124f67210c9e65d99a6bceca536 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -21,6 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp" struct ProblemSize final { @@ -28,9 +29,9 @@ struct ProblemSize final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; }; struct ProblemSizeStreamK final @@ -39,18 +40,45 @@ struct ProblemSizeStreamK final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; ck::index_t NumSKBlocks = -1; }; +struct ProblemSizeStreamK_universal final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; + + ck::index_t Grid_size = -1; // defaults to max occupancy + ck::index_t Streamk_sel = 1; // defaults to 1-tile SK +}; + +struct ProblemSizeSplitK final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; + + ck::index_t KBatch = 1; +}; struct ExecutionConfig final { - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; + // 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU + int do_verification = 3; + int init_method = 2; + bool time_kernel = false; }; template @@ -99,7 +127,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl @@ -110,6 +138,57 @@ bool parse_cmd_args(int argc, return true; } +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeStreamK_universal& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.Streamk_sel = std::stoi(argv[10]); + problem_size.Grid_size = std::stoi(argv[11]); + } + } + else + { + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + return false; + } + + return true; +} + template <> bool 
parse_cmd_args(int argc, char* argv[], @@ -147,12 +226,62 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + return false; + } + + return true; +} + +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeSplitK& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.KBatch = std::stoi(argv[10]); + } + } + else + { + std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl - << "arg10: NumSKBlocks(optional)" << std::endl; + << "arg10: KBatch" << std::endl; return false; } diff --git a/example/01_gemm/gemm_dl_dpp8_fp16.cpp b/example/01_gemm/gemm_dl_dpp8_fp16.cpp deleted file mode 100644 index ea0ba390767afd5bb52f2d2b6ce58bb0d6e7f16e..0000000000000000000000000000000000000000 --- a/example/01_gemm/gemm_dl_dpp8_fp16.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "common.hpp" - -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using ALayout = Col; -using BLayout = Row; -using CLayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CElementOp = PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDlDpp8 -// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| -// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| -// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index b5fecb97521bd2d4213a9435b01dafe45a43be6e..b9284b2783ad49200452bf0fc8fad7e97da2d352 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 212b72f2a6a060a602c5c640709da77684c7c55e..16842136412018885c497f2e54a304492362e7de 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp index e55ae140130c0779c1e81e0bdc84b6724c6bac86..43c0cfe2e014245da149cd6470e1807c7222e5fc 100644 --- a/example/01_gemm/gemm_dl_int4.cpp +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -43,3 +41,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +#endif \ No newline at end of file diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index 1840390aa9e02e85c88328baf358546db362ab24..1e64e9a0a33142ce70108bd2e0949668e6622ca8 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dpp_fp16.cpp b/example/01_gemm/gemm_dpp_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30faf542dddd4b3f713cadfb0187f3ed5a38dc42 --- /dev/null +++ b/example/01_gemm/gemm_dpp_fp16.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CDataType = ck::half_t; + +using F16 = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 64, 64, 64, 8, 2, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 5, 1>; +// // clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_bf16.cpp b/example/01_gemm/gemm_wmma_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a87426094f1af4329d421770ab707a4b13fc4fdf --- /dev/null +++ b/example/01_gemm/gemm_wmma_bf16.cpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 2, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index b11fe76ab2ce031708be7addd7f84be0f15adc6b..28ab878ac3a1a199f1ca52ffeaacbb352b31d0c9 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -19,20 +19,66 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmMNKPadding, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>; + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 2, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_int8.cpp b/example/01_gemm/gemm_wmma_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a88e42d42b136c0bd28d1bd20b7263d65258e8b5 --- /dev/null +++ b/example/01_gemm/gemm_wmma_int8.cpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" + +using ADataType = int8_t; +using BDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using CDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 2, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 3cac55ef4702856dd08dc223c02ea25113dbbf32..6cfff30dbd9329cc44cb47b65917a7aaa0f86986 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,6 +33,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..108c100cbdf88c0a8e33e9d372daddf2a56894ab --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_rtn.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/utility/type_convert.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_v3.cpp b/example/01_gemm/gemm_xdl_bf16_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e538aee1fe3899fb1374a731df4b0e50f89a6473 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_v3.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 128, + 64, 8, 8, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_f8.cpp b/example/01_gemm/gemm_xdl_f8.cpp deleted file mode 100644 index 10159267776f6f33469affc618192e29d62763ac..0000000000000000000000000000000000000000 --- a/example/01_gemm/gemm_xdl_f8.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; -using CDataType = ck::f8_t; -using AccDataType = float; -using CShuffleDataType = ck::f8_t; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CElementOp = PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, 
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 54fbd9cdd4983c1a810f3c8a1ac1521b363af690..07d51855d6b60bfa32b4e815627dfbe64a0006d8 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -9,13 +9,13 @@ using ADataType = ck::half_t; using BDataType = ck::half_t; using AccDataType = float; -using CShuffleDataType = float; +using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; using F16 = ck::half_t; using ALayout = Row; -using BLayout = Col; +using BLayout = Row; using CLayout = Row; using AElementOp = PassThrough; @@ -30,7 +30,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>; // // clang-format on // clang-format off @@ -39,7 +39,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 8, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>; // clang-format on using DeviceGemmInstance = DeviceGemmInstance1; @@ -47,6 +47,17 @@ using DeviceGemmInstance = DeviceGemmInstance1; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_f8.cpp b/example/01_gemm/gemm_xdl_fp16_f8.cpp deleted file mode 100644 index d3cf3d397a5fd329a58dc3a967aa4c8f0bb2b696..0000000000000000000000000000000000000000 --- a/example/01_gemm/gemm_xdl_fp16_f8.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = ck::half_t; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CElementOp = PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto LoopSched = ck::make_default_loop_scheduler(); -static constexpr auto PipelineVer = ck::PipelineVersion::v1; -using ComputeType = ck::half_t; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| ComputeType| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| | -// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_fp8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a996d034e6a9d6bc6243c9bfb4c3db98bdee4b29 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_fp8.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeType = ck::half_t; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| ComputeType| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| | +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using ReferenceGemmInstanceGPU = 
ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2e27fc66f9456b67ff860e5926f6a24149ab14c3 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 64, + 16, 16, + 64, 16, 8, + 16, 16, + 1, 1, + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b163962b95132e2749151f88c682968417fd361 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
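+// Example: fp16 Stream-K GEMM (row-major A, B and C, MN padding) on the
+// DeviceGemm_Xdl_CShuffle_Streamk_V3 kernel with the Intrawave scheduler and v3
+// block-GEMM pipeline; the Stream-K selection policy and grid size are part of the
+// problem configuration parsed in run_gemm_example_streamk_v2.inc.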
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 224, 256, + 64, 8, 2, + 16, 16, + 7, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 2, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v2.cpp b/example/01_gemm/gemm_xdl_fp16_v2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ecd3b7be5dafca9b7fec16162f3ffeaf1ca92700 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_v2.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using F16 = ck::half_t; +using F32 = float; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV2< + ALayout, BLayout, CLayout, + F16, F16, F16, F32, F16, + PassThrough, PassThrough, PassThrough, GemmDefault, + 2, 256, + 256, 256, + 32, 8, 4, + 32, 32, + 4, 4, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 4, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::LoopScheduler::Default, ck::PipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v3.cpp b/example/01_gemm/gemm_xdl_fp16_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ad370f570efd98e90c2bd53fe7522e5ee249586a --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_v3.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
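+// Example: fp16 GEMM (row-major A, B and C, MN padding) on the DeviceGemm_Xdl_CShuffleV3
+// kernel with the Intrawave scheduler and v3 block-GEMM pipeline, driven through the
+// split-K runner in run_gemm_example_v2.inc.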
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 224, 256, + 64, 8, 2, + 16, 16, + 7, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 2, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 8361576299c3da3727cc467387e9717fceb93ce3..5afb3d1554d39c283f3efd9c6de98f4e28b4208d 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -41,6 +41,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl BElementOp, CElementOp>; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c75a44d219c9155be3d31e3638b156b7f344835 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
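+// Example: fp8 GEMM (row-major A, column-major B) on the XDL CShuffle kernel. A, B and C
+// are f8_t, accumulation is performed in fp32, and both compute types are f8_t.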
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using CDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeTypeA = ck::f8_t; +using ComputeTypeB = ck::f8_t; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1dec165abdcc3cf28c7ff3464073a6e8efa75283 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
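+// Example: mixed fp8/bf8 GEMM (row-major A, column-major B) on the XDL CShuffle kernel.
+// A is f8_t, B is bf8_t, accumulation is performed in fp32, and C is written out as fp16.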
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::bf8_t; +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeTypeA = ck::f8_t; +using ComputeTypeB = ck::bf8_t; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp8_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..da891267b233f5940c7d811fcef28f1f4abe0b9e --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_v3.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
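+// Example: fp8 GEMM with fp16 output (row-major A, column-major B) on the
+// DeviceGemm_Xdl_CShuffleV3 kernel with the Intrawave scheduler and v3 block-GEMM
+// pipeline, driven through the split-K runner in run_gemm_example_v2.inc.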
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 224, 256, + 128, 16, 16, + 16, 16, + 7, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ck::f8_t>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp index f6238c7aa5040d8e409a139fe8c144835282cdc9..fb4f383fae1e96a555a2c83f0b842aaec588068c 100644 --- a/example/01_gemm/gemm_xdl_int4.cpp +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +#endif \ No newline at end of file diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index cc03200b9d153e8e22e13e6c77a52ad76ed79174..3237f1a61cc59525de788bc4a84f9e3d2911b974 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,6 +33,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..62037f774078890ab655d6abcb2c446d92c0d5af --- /dev/null +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
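+// Example: fp16 GEMM (row-major A, column-major B). The USING_DIRECT_LOADS switch selects
+// between the LDS direct-load kernel (DeviceGemm_Xdl_CShuffle_LdsDirectLoad) and the
+// regular DeviceGemm_Xdl_CShuffle kernel.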
+ +#include + +#include "common.hpp" + +#define USING_DIRECT_LOADS 1 +#if USING_DIRECT_LOADS +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#endif + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +#if USING_DIRECT_LOADS +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#else +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#endif +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..75971bdecf3cfd204f9494d725d753d902eba141 --- /dev/null +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "common.hpp" + +#define USING_DIRECT_LOADS 1 +#if USING_DIRECT_LOADS +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#endif + +using F32 = float; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +#if USING_DIRECT_LOADS +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#else +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#endif +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp index 7d433b6145017c3aa19f4ae5a9e4bf31101b1025..5a02457dafd1e021b2c0fa71bd1498c891135304 100644 --- a/example/01_gemm/gemm_xdl_streamk.cpp +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -44,6 +44,17 @@ using DeviceGemmInstance = DeviceGemmStreamK; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index b0f963fee51c20b089811ba2b6396b9d7125d8f6..d8672f6a0cc7e3a632171c81d9509ad3e39349e1 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -37,6 +37,17 @@ using DeviceGemmInstance = DeviceGemmInstance; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 7be2539d903974cd676e92ddb3d692a997abf1cf..bafec3f358037cc9c4d3758e8395556e5eeb4ce8 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -1,10 +1,92 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 2e-1; + } + else if constexpr(std::is_same_v) + { + return 2e-1; + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 2e-1; + } + else if constexpr(std::is_same_v) + { + return 2e-1; + } + else + { + return 1e-3; + } +} + template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -34,21 +116,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) }; auto f_get_default_stride = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(stride == 0) + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) { - // give a chance if stride is zero, return a default packed stride + // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) { - return col; + return static_cast(col); } else { - return row; + return static_cast(row); } } else - return stride; + return static_cast(stride); }; StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); @@ -68,13 +150,30 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; - default: + case 2: ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + break; + case 3: + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 4: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + break; + case 5: + 
ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n); + break; + default: + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(a_m_k); + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(b_k_n); } Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -95,6 +194,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) * + c_m_n_device_ref_result.mDesc.GetElementSpaceSize()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); @@ -207,6 +308,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) } #endif } + else + { + // When the Problem Type and Problem Size does not fit. + + std::cerr << gemm.GetTypeString() << ": the instance does not support the problem config." + << std::endl; + return true; + } std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = @@ -219,14 +328,18 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << gemm.GetTypeString() << std::endl; - if(config.do_verification) + bool pass = true; + + if((config.do_verification == 1) || (config.do_verification == 3)) { + // CPU verification auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = ref_gemm.MakeArgument( a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + std::cout << "Running verification on CPU." << std::endl; ref_invoker.Run(ref_argument); #ifdef BUILD_INT4_EXAMPLE @@ -240,11 +353,45 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #else c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); #endif } - return true; + if((config.do_verification == 2) || (config.do_verification == 3)) + { + // GPU verification + auto ref_gemm_gpu = ReferenceGemmInstanceGPU{}; + auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker(); + + auto ref_argument_gpu = ref_gemm_gpu.MakeArgument( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_ref_buf.GetDeviceBuffer()), + M, + N, + K, + a_element_op, + b_element_op, + c_element_op); + + std::cout << "Running verification on GPU." 
<< std::endl; + ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{}); + + c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data()); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_device_ref_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + return pass == true; } bool run_gemm_example(int argc, char* argv[]) diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc new file mode 100644 index 0000000000000000000000000000000000000000..8ed8b81bec13c8d1432077589c83d71866387e3c --- /dev/null +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto Grid_size = problem_size.Grid_size; + auto Streamk_sel = problem_size.Streamk_sel; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + auto f_get_default_streamk_policy = [](ck::index_t streamk_sel) { + if(streamk_sel == -1) + { + return static_cast(4); + } + else + return static_cast(streamk_sel); + }; + + StrideA = f_get_default_stride(M, K, 
StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Streamk_sel = f_get_default_streamk_policy(Streamk_sel); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2_Streamk_Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + Streamk_sel, + Grid_size, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if((config.do_verification == 1) || (config.do_verification == 3)) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, 
StreamConfig{nullptr, false, 1}); +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); +#endif + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_universal_streamk_example(int argc, char* argv[]) +{ + ProblemSizeStreamK_universal problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc new file mode 100644 index 0000000000000000000000000000000000000000..71524fdecf1306938bc381e84ac873fc11142ede --- /dev/null +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto 
f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + 
a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if((config.do_verification == 1) || (config.do_verification == 3)) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); +#endif + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 9989e45e0c275189c592414a2fc8e2a28b2fc511..2c20b96eee346686bcebfec2499fb162e5f13825 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,19 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) - set(target 1) - endif() -endforeach() - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() -endif() +add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) +add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) +add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) diff --git a/example/02_gemm_bilinear/README.md b/example/02_gemm_bilinear/README.md index 9eb87e1e3479d72497ec72956b1de649b0ff735f..a407ce24f7bf72a2c19dbf1a83c455ee955baa2c 100644 --- a/example/02_gemm_bilinear/README.md +++ b/example/02_gemm_bilinear/README.md @@ -9,20 +9,3 @@ #arg11 to 12: alpha, beta ./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5 ``` -Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16) -``` -a_m_k: 
dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c0_grid_desc_m_n_{ 3840, 4096} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s -error: 0 -max_diff: 0, 558.5, 558.5 -``` diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 877792d7409f48ee562cded6629f7e1bbde650d9..18731e810e14db2b81a85264f19718a997e1cfa0 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -17,6 +17,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" struct AlphaBetaAdd { @@ -65,48 +66,49 @@ using CDEElementOp = AlphaBetaAdd; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstance = - ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 256, - 128, - 256, - 8, - 8, - 16, - 16, - 4, - 4, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, - 1, - S<1, 32, 1, 8>, - 8>; +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; int main(int argc, char* argv[]) { @@ -174,6 +176,14 @@ int main(int argc, char* argv[]) exit(0); } + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; @@ -264,7 +274,7 @@ int main(int argc, char* argv[]) float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; + << device_op.GetTypeString() << std::endl; e_device_buf.FromDevice(e_m_n_device_result.mData.data()); diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..87812369bd1f1523cee5672f3f14624a94cc8113 --- /dev/null +++ 
b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" + +struct AlphaBetaAdd +{ + AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + std::int8_t& e, const std::int32_t& c, const std::int8_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + int alpha_; + int beta_; +}; + +template +using S = ck::Sequence; + +using I8 = std::int8_t; +using I32 = std::int32_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DDataType = I8; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + int alpha = 1; + int beta = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = 
std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index a247a052cba40410e8879a57a0e3c35a57b5645f..35c54abac03094f24187df2503aa02b6812c20f3 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1,10 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() -endif() +add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 15ec62c89fcd39e24e7bb1ae5f8e7d366f2138d6..be47665a262ec6619816c249bdfdb96ba3c8ae16 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,28 +1,27 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +add_custom_target(example_gemm_add_add_fastgelu_xdl) +add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) + +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) + +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) + +add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) + 
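The new int8 WMMA bilinear example above builds its host reference by running a plain int32 GEMM and then applying the AlphaBetaAdd functor element by element, i.e. E = alpha * C + beta * D narrowed back to int8. The standalone sketch below mirrors that epilogue without any CK headers; AlphaBetaAddSketch and the std::clamp saturation are illustrative additions, since the patch routes the narrowing through ck::type_convert and its exact rounding/saturation behaviour is not shown in this diff.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Standalone stand-in for the int8 bilinear epilogue used by the example:
// E = alpha * C + beta * D, with C accumulated in int32 and E stored as int8.
struct AlphaBetaAddSketch
{
    int alpha;
    int beta;

    void operator()(std::int8_t& e, const std::int32_t& c, const std::int8_t& d) const
    {
        const std::int32_t v = alpha * c + beta * static_cast<std::int32_t>(d);
        // std::clamp stands in for the narrowing step; whether ck::type_convert
        // saturates or truncates is not visible in the patch.
        e = static_cast<std::int8_t>(std::clamp<std::int32_t>(v, -128, 127));
    }
};

int main()
{
    AlphaBetaAddSketch op{1, 1};
    std::int8_t e = 0;
    op(e, /*c=*/300, /*d=*/5); // 305 saturates to 127 in this sketch
    std::cout << static_cast<int>(e) << '\n';
    return 0;
}
```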
+list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_add_add_fastgelu_xdl) - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32 gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) + set(target 1) endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) - endif() - set(target 1) - endif() -endforeach() \ No newline at end of file +endforeach() diff --git a/example/04_gemm_add_add_fastgelu/README.md b/example/04_gemm_add_add_fastgelu/README.md index 08a55fb9a37f9f7823ae8882d4d36e9bd96ee6fd..7b0d003e59cb6d3c2e53a064b1845f7ab3c5c98a 100644 --- a/example/04_gemm_add_add_fastgelu/README.md +++ b/example/04_gemm_add_add_fastgelu/README.md @@ -8,16 +8,3 @@ #arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" ./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} -d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8> -``` diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp index f206bbeb411bb0de0144015d70a892b5587d6f0e..1d0b0f7861858f44dbd692e52f206e157a1ebe6a 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -58,3 +56,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc index cb3147bcd718c115e47163af58e813bd2cf39bfe..cb0271c81f1d3feac72004da7e37a288fc97d463 100644 --- a/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc +++ b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc @@ -105,7 +105,8 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC if(!device_op.IsSupportedArgument(argument)) { - throw std::runtime_error("wrong! 
this device_op instance does not support this problem"); + std::cerr << device_op.GetTypeString() << " does not support this problem" << std::endl; + return true; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 1af1e6c8582f6f0f4ccf3902dc1dd0c1d5bd02d5..8a295d14c4ec5f45aba77d31cd98ff78cecb6fb4 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,35 +1,13 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) - endif() - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) - endif() - set(target 1) - endif() -endforeach() - -if(DL_KERNELS) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) - endif() -endif() +add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) +add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) +add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16_comp_fp8 convnd_fwd_xdl_fp16_comp_fp8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp) +add_example_executable(example_convnd_fwd_xdl_bf8_fp8 convnd_fwd_xdl_bf8_fp8.cpp) +add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) +add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) +add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/09_convnd_fwd/README.md b/example/09_convnd_fwd/README.md index 9ab5fee549d44aefd22590d66681127a2816d9f9..22f90ea29af28a1e7b38030971ecf49a8b99b849 100644 --- a/example/09_convnd_fwd/README.md +++ b/example/09_convnd_fwd/README.md @@ -16,17 +16,3 @@ # , (ie RightPy, RightPx for 2D) ./bin/example_convnd_fwd_xdl 0 1 100 ``` - -Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32) -``` -input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} -weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} -output: dim 4, lengths {128, 256, 36, 36}, strides 
{331776, 1, 9216, 256} -arg.a_grid_desc_k0_m_k1_{432, 165888, 4} -arg.b_grid_desc_k0_n_k1_{432, 256, 4} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 100 times... -Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s -``` diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 109b8f9ee34096601f0dd485db6840e0674c1937..b0fd6a382a244afd55872db4a51e393254968151 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -27,6 +27,88 @@ void print_helper_msg() << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + template (), + get_atol()); } return true; diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp index 74cf91d16017bd24096dac6348e5a8815856a51a..b6bb03e1e58d04f1804387be42ee58e5fe7800a8 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0fc9e7b5dd781ba76befdd0fe4576ae6ec52eba2 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT 
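Both run_gemm_example_v2.inc and convnd_fwd_common.hpp above gain get_rtol/get_atol helpers that pick verification tolerances at compile time from the element data type; note that the template arguments of std::is_same_v (and of the helpers themselves) have been stripped in the extracted diff, so only the return values and branch comments survive. The sketch below reconstructs the dispatch pattern with standard types plus placeholder tags standing in for CK's reduced-precision types; which CK type each branch tests is an inference from the comments, not authoritative.

```cpp
#include <cstdio>
#include <type_traits>

// Placeholder tags standing in for ck::half_t and ck::f8_t (assumption, for
// illustration only).
struct half_t_tag {};
struct f8_t_tag {};

// Compile-time tolerance selection in the style of the get_rtol helper added
// by this patch; the numeric values are taken from the visible diff.
template <typename EDataType>
constexpr double get_rtol_sketch()
{
    if constexpr(std::is_same_v<EDataType, float>)           { return 1e-3; }
    else if constexpr(std::is_same_v<EDataType, double>)     { return 1e-6; }
    else if constexpr(std::is_same_v<EDataType, half_t_tag>) { return 1e-3; }
    else if constexpr(std::is_same_v<EDataType, f8_t_tag>)   { return 1e-1; } // 240 vs 224 accepted
    else                                                     { return 1e-3; }
}

int main()
{
    std::printf("rtol<float>  = %g\n", get_rtol_sketch<float>());
    std::printf("rtol<double> = %g\n", get_rtol_sketch<double>());
    std::printf("rtol<f8 tag> = %g\n", get_rtol_sketch<f8_t_tag>());
    return 0;
}
```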
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using ComputeType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + ComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9eba00993a292ddfdc0539ee9a9ff20ab5f0c4b0 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using AComputeType = ck::bf8_t; +using BComputeType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeType, + BComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index f6d69bafd48a4b8bff968c2fb8600793b1310294..064a971478f2f8445de318122f4a346834f70a0c 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..346ab8d953097f67295ff3a11ebca26bd54c338f --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; +using ComputeType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + 
ComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index 6c3171f6157ac4a83a488af12089a97accfd2c7c..36517e569dd5a0574190a01c317fec10a718badf 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp index 9977a496d229876c083e0afe99e17ff0fa5fca05..659283c3f05b8a23c639eb1420a9f67f54d788d6 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef130148bca174f8ecf2b47d90c5f3540376de09 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
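convnd_fwd_xdl_fp16_comp_fp8.cpp above keeps half-precision storage while requesting f8 arithmetic through the trailing ComputeType parameter, and converts the result with the element-wise UnaryConvert operation. The fragment below sketches what such a conversion functor looks like; UnaryConvertSketch is an illustration of the pattern using standard types, not the CK definition, which presumably performs the cast through ck::type_convert.

```cpp
#include <cstdio>

// Sketch of an element-wise unary conversion functor in the style of the
// UnaryConvert output operation used by the fp16_comp_fp8 example.
struct UnaryConvertSketch
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = static_cast<Y>(x); // the CK version would go through ck::type_convert
    }
};

int main()
{
    UnaryConvertSketch convert;
    double acc = 3.141592653589793; // wide accumulator value
    float out  = 0.0f;              // narrower output element
    convert(out, acc);
    std::printf("converted: %f\n", out);
    return 0;
}
```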
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using ComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + ComputeDataType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53a12377c5f1f0fad7dbeeaf89f7a2171a14e42f --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using AComputeType = ck::f8_t; +using BComputeType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeType, + BComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index bf084b3cc0b6ceab326b2b21e963e0ce2d4bdcad..0180e6e71890d8db2eb2f6ac6423003ed95bbb33 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index e7d941ae6b3dddee9f8dfa357cf3426045d0e828..ef8bea185034f402db465903316c9eb1c6fb114a 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,28 +1,17 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_fwd_reduce_xdl) - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() \ No newline at end of file +add_custom_target(example_convnd_fwd_reduce_xdl) +add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) + +add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) + +add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) + +add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) 
+add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 137b0d1ff0fb0a7cb7159d52f54b4cd3496629ae..036f288d0a420f55f28fa5aa0b51d4d9d43434d6 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -80,7 +80,7 @@ using RLayout = typename LayoutSettingSelector::RLayout; struct ExecutionConfig final { bool do_verification = true; - int init_method = 1; + int init_method = 2; bool time_kernel = false; }; @@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc, inline HostTensorDescriptor make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size) { - std::vector dimensions{problem_size.G_, problem_size.N_}; + std::vector dimensions{problem_size.G_, problem_size.N_}; ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions)); diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp index bf7127502faac94e6858d37c389f316b7ead35a6..ce8ff0b49a9b1d72ab6fca2914db680ec9497c8c 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #define BUILD_INT4_EXAMPLE @@ -24,3 +22,4 @@ using RsDataType = ck::Tuple; #include "run_convnd_fwd_max_example.inc" int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } +#endif diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index cebfeb51d63eac30243d8e3b0468b821f6fc1eb3..d61aee81a45836b476d83202946d0087327fdc1c 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -73,16 +73,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, Tensor conv_output_device(conv_output_g_n_k_wos_desc); Tensor r0_device(r0_desc); + std::cout << "input: " << conv_input.mDesc << std::endl; + std::cout << "weight: " << conv_weight.mDesc << std::endl; + std::cout << "output: " << conv_output_device.mDesc << std::endl; + std::cout << "reduction: " << r0_device.mDesc << std::endl << std::endl; + switch(config.init_method) { case 0: break; case 1: ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_input); - ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_weight); + ck::utils::FillUniformDistributionIntegerValue{-1, 1}(conv_weight); + break; + case 2: + ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_input); + ck::utils::FillUniformDistribution{-1, 1}(conv_weight); break; default: - ck::utils::FillUniformDistribution{-5, 5}(conv_input); - ck::utils::FillUniformDistribution{-5, 5}(conv_weight); + ck::utils::FillUniformDistribution{-8, 7}(conv_input); + ck::utils::FillUniformDistribution{-1, 1}(conv_weight); } DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize()); @@ -161,15 +170,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, return false; } + // XXX: DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle will not initialize r0. + r0_device_buf.SetValue(ck::NumericLimits::Lowest()); + const float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - const std::size_t flop = problem_size.GetFlops(); - const std::size_t num_btype = problem_size.GetByte(); + if(config.time_kernel) + { + const std::size_t flop = problem_size.GetFlops(); + const std::size_t num_btype = problem_size.GetByte(); - const float tflops = static_cast(flop) / 1.E9 / avg_time; - const float gb_per_sec = num_btype / 1.E6 / avg_time; - std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << conv.GetTypeString() << std::endl; + const float tflops = static_cast(flop) / 1.E9 / avg_time; + const float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv.GetTypeString() << std::endl; + } + else + { + std::cout << "FINISHED: " << conv.GetTypeString() << std::endl; + } if(config.do_verification) { @@ -189,6 +208,7 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, BElementOp{}, PassThrough{}); + std::cout << "\nRunning verification on CPU." 
<< std::endl; ref_invoker.Run(ref_argument); Tensor r0_host(r0_device.mDesc); @@ -273,13 +293,18 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, conv_output_device_buf.FromDevice(conv_output_device.mData.data()); r0_device_buf.FromDevice(r0_device.mData.data()); - return ck::utils::check_err(conv_output_device, - conv_output_host, - "Error: incorrect results! (Matrix E)", - 1e-5f, - 1e-4f) && - ck::utils::check_err( - r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f); + auto pass = ck::utils::check_err(conv_output_device, + conv_output_host, + "Error: incorrect results! (Matrix E)", + 1e-3f, + 1e-3f); + pass = + pass && ck::utils::check_err( + r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-3f, 1e-3f); + if(pass) + std::cout << "Verification on CPU: PASS" << std::endl; + + return pass; } return true; diff --git a/example/12_reduce/CMakeLists.txt b/example/12_reduce/CMakeLists.txt index 6e58ed933807d66a89967b6e278e084ccbe77d4d..03381a449f099a9ad2e54ed59cd0889bdde7313c 100644 --- a/example/12_reduce/CMakeLists.txt +++ b/example/12_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ add_example_executable(example_reduce_blockwise reduce_blockwise.cpp) +add_example_executable(example_reduce_threadwise_multi_d reduce_threadwise_multi_d.cpp) add_example_executable(example_reduce_multiblock_atomic_add reduce_multiblock_atomic_add.cpp) add_example_executable(example_reduce_blockwise_two_call reduce_blockwise_two_call.cpp) diff --git a/example/12_reduce/README.md b/example/12_reduce/README.md index 76d28527bb8aa521611e54d172b427aabf9fb998..bcffa684c8f2c5c34efbf4710656b0124dfa8a42 100644 --- a/example/12_reduce/README.md +++ b/example/12_reduce/README.md @@ -2,7 +2,7 @@ ## Run ```example_reduce_blockwise``` ```bash -# -D : input 3d/4d/5d tensor lengths +# -D : input 3D/4D/5D tensor lengths # -R : reduce dimension ids # -v : verification (0=no, 1=yes) #arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4) @@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr ## Run ```example_reduce_multiblock_atomic_add``` ```bash -# -D : input 3d/4d/5d tensor lengths +# -D : input 3D/4D/5D tensor lengths # -R : reduce dimension ids # -v : verification (0=no, 1=yes) #arg1: data type (0: fp32, 1: fp64) diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index 9a736d4cfac9d1062c568e795592711e30d18586..309100cddece844faa7b04426be90a70756b607a 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
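+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.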
#include #include @@ -255,34 +255,61 @@ int main(int argc, char* argv[]) else { // for testing half_t + pass = + pass && reduce_blockwise_test( + true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f); pass = pass && reduce_blockwise_test( true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); // for testing float + pass = + pass && reduce_blockwise_test( + true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f); + pass = pass && reduce_blockwise_test( true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); // for testing double + pass = + pass && reduce_blockwise_test( + true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f); + pass = pass && reduce_blockwise_test( true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); // for testing bhalf_t + pass = pass && + reduce_blockwise_test( + true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f); + pass = pass && reduce_blockwise_test( true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); // for testing int8_t + pass = + pass && reduce_blockwise_test( + true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f); + pass = pass && reduce_blockwise_test( true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 // for testing int4_t using AVG operation + pass = + pass && reduce_blockwise_test( + true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f); + pass = pass && reduce_blockwise_test( true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); // for testing int4_t using MAX operation + pass = + pass && reduce_blockwise_test( + true, 2, true, {3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, {0, 1, 2}, 1.0f, 0.0f); + pass = pass && reduce_blockwise_test( true, 2, true, {16, 64, 32, 960}, {0, 1, 2}, 1.0f, 0.0f); #endif diff --git a/example/12_reduce/reduce_blockwise_impl.hpp b/example/12_reduce/reduce_blockwise_impl.hpp index 7f8394a7301cc3da8271e1e8c41482d0b06a2859..f1225d86e498a191830690ae8bcec6b0eaada9bb 100644 --- a/example/12_reduce/reduce_blockwise_impl.hpp +++ b/example/12_reduce/reduce_blockwise_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -316,7 +316,17 @@ int reduce_blockwise_impl(bool do_verification, auto invoker_ptr = reduce.MakeInvokerPointer(); - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + int log_level = 0, cold_niters = 5, nrepeat = 50; + if(beta != 0.0f) + { + std::cerr << "Warning: With beta != 0.0f there must be only one repeat for correct results " + "since out memory is being overwritten." + << std::endl; + cold_niters = 0; + nrepeat = 1; + } + float avg_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, log_level, cold_niters, nrepeat}); std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) + invariant_total_length * sizeof(InOutDataType); diff --git a/example/12_reduce/reduce_example_common.hpp b/example/12_reduce/reduce_example_common.hpp index 5f9a48804a7268cef5a4a065d0c875565bdb3e9e..08cd6e7ff9b3c8f78d3303c676a8244613fd9368 100644 --- a/example/12_reduce/reduce_example_common.hpp +++ b/example/12_reduce/reduce_example_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -38,7 +38,8 @@ struct ReduceShape static constexpr ck::index_t NumReduceDim_ = NumReduceDim; }; -using reduce_shape_instances = std::tuple, +using reduce_shape_instances = std::tuple, + ReduceShape<3, 1>, ReduceShape<3, 2>, ReduceShape<4, 1>, ReduceShape<4, 2>, diff --git a/example/12_reduce/reduce_threadwise_multi_d.cpp b/example/12_reduce/reduce_threadwise_multi_d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..174c312813ac087348d8f0eae9aa5aa63f0f9d34 --- /dev/null +++ b/example/12_reduce/reduce_threadwise_multi_d.cpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/utility/reduction_enums.hpp" +#include "reduce_threadwise_multi_d_impl.hpp" +#include "reduce_example_common.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths = {16, 64, 32, 16}; + std::vector reduceDims = {0}; + std::vector scales = {1.0f, 0.0f}; + + bool do_verification = true; + int data_type = 1; + int init_method = 2; + bool time_kernel = true; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--reduceDims or -R, comma separated list of to-reduce dimensions" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4)" + << std::endl; + std::cout << "Arg2 -- init method (0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg3 -- time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:R:v:l:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'R': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + reduceDims = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 3 > argc) + { + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + }; + + data_type = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + return (0); + }; +}; + +template +bool reduce_threadwise_multi_d_test(bool 
do_verification, + int init_method, + bool time_kernel, + const std::vector& inLengths, + const std::vector& reduceDims, + float alpha, + float beta) +{ + bool matched = false; + int result = 0; + + const auto tuple_object = reduce_shape_instances{}; + + static_for<0, std::tuple_size::value, 1>{}([&](auto i) { + if(matched) + return; + + using ShapeType = remove_cvref_t(tuple_object))>; + + if(ShapeType::Rank_ != inLengths.size() || ShapeType::NumReduceDim_ != reduceDims.size()) + return; + + std::array arrReduceDims; + + ck::ranges::copy(reduceDims, arrReduceDims.begin()); + + result = reduce_threadwise_multi_d_impl( + do_verification, init_method, time_kernel, inLengths, arrReduceDims, alpha, beta); + + matched = true; + }); + + return (result == 0) ? true : false; +}; + +constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::AVG; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; + +int main(int argc, char* argv[]) +{ + bool pass = true; + + if(argc > 1) + { + SimpleAppArgs arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + if(arg.data_type == 0) + { + pass = reduce_threadwise_multi_d_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } + else if(arg.data_type == 1) + { + pass = + reduce_threadwise_multi_d_test( + arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inLengths, + arg.reduceDims, + arg.scales[0], + arg.scales[1]); + } + } + else + { + // for testing half_t + pass = pass && reduce_threadwise_multi_d_test( + true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f); + + // for testing float + pass = pass && + reduce_threadwise_multi_d_test( + true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f); + + // for testing bhalf_t + pass = pass && reduce_threadwise_multi_d_test( + true, 2, true, {16, 64, 32, 960}, {0}, 1.0f, 0.0f); + } + + return (pass ? 0 : 1); +}; diff --git a/example/12_reduce/reduce_threadwise_multi_d_impl.hpp b/example/12_reduce/reduce_threadwise_multi_d_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3e5006a644dd66d4976c79fa1daa8b0dc9b481d9 --- /dev/null +++ b/example/12_reduce/reduce_threadwise_multi_d_impl.hpp @@ -0,0 +1,307 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" + +#include "reduce_example_common.hpp" + +template +int reduce_threadwise_multi_d_impl(bool do_verification, + int init_method, + bool time_kernel, + const std::vector& inLengths, + const std::array& reduceDims, + float alpha, + float beta) + +{ + using namespace ck; + using namespace ck::tensor_operation::device; + + constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 
1 : Rank - NumReduceDim; + + constexpr bool op_support_indices = + (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || + ReduceOpId == ReduceTensorOp::AMAX); + + constexpr bool invalid_reduce_1 = OutputIndex && !op_support_indices; + + // 1) If InOutDataType is half_t, must use half_t as AccDataType for indexable reduction + // operations 2) If InOutDataType is half_t, must use float as AccDataType for non-indexable + // reduction operations + constexpr bool invalid_reduce_2 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); + + // 1) If InOutDataType is float, must use float as AccDataType for indexable reduction + // operations + constexpr bool invalid_reduce_3 = + std::is_same::value && + (op_support_indices && !std::is_same::value); + + // 1) If InOutDataType is int8_t or int4_t, must use int8_t as AccDataType for indexable + // reduction operations 2) If InOutDataType is int8_t or int4_t, must use int32_t as AccDataType + // for non-indexable reduction operations + constexpr bool invalid_reduce_4 = + std::is_same::value && + ((!op_support_indices && !std::is_same::value) || + (op_support_indices && !std::is_same::value)); + + // 1) If InOutDataType is int8_t or int4_t, the supported operation must be either indexable + // operations or ADD/AVG + constexpr bool invalid_reduce_5 = std::is_same::value && + (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD && + ReduceOpId != ReduceTensorOp::AVG); + + // 1) If InOutDataType is bhalf_t, must use float as AccDataType for all reduction operations + constexpr bool invalid_reduce_6 = + std::is_same::value && !std::is_same::value; + + constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || + invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6); + + if constexpr(invalid_reduce) + { + std::cerr << "The reduction setting is invalid, exiting!" 
<< std::endl; + return (-1); + }; + + using PassThrough = tensor_operation::element_wise::PassThrough; + using Add = tensor_operation::element_wise::Add; + + using ReduceOperation = typename reduce_binary_operator::opType; + using InElementwiseOperation = PassThrough; + using OutElementwiseOperation = Add; + + using InOutDataTypeInDevice = InOutDataType; + + using DeviceReduceInstance = + ck::tensor_operation::device::DeviceReduceThreadWiseMultiD, + AccDataType, + InOutDataTypeInDevice, + Rank, + NumReduceDim, + ReduceOperation, + InElementwiseOperation, + OutElementwiseOperation, + 256, // BlockSize + 4, // MThreadSliceSize + 1, // KThreadSliceSize + 0, // InSrcVectorDim + 1, // InSrceVectorSize + 1, + Sequence<1>>; // OutDstVectorSize + + Tensor in(inLengths); + + std::vector outLengths; + + auto invariantDims = get_invariant_dims(reduceDims); + + if(invariantDims.empty()) + outLengths.push_back(1); + else + for(auto dim : invariantDims) + outLengths.push_back(inLengths[dim]); + + Tensor out_ref(outLengths); + Tensor out(outLengths); + + Tensor d0(outLengths); + + Tensor out_indices_ref(outLengths); + Tensor out_indices(outLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in.mDesc.GetElementSize() / invariant_total_length; + + std::size_t num_thread = 1; + + if(do_verification) + { + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + d0.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + d0.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + d0.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, + num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InOutDataTypeInDevice) * in.mDesc.GetElementSpaceSize()); + DeviceMem d0_dev(sizeof(InOutDataTypeInDevice) * d0.mDesc.GetElementSpaceSize()); + DeviceMem out_dev(sizeof(InOutDataTypeInDevice) * out.mDesc.GetElementSpaceSize()); + + in_dev.ToDevice(in.mData.data()); + d0_dev.ToDevice(d0.mData.data()); + + if(beta != 0.0f) + { + out_dev.ToDevice(out.mData.data()); + }; + + size_t indicesSizeInBytes = OutputIndex ? 
out.mDesc.GetElementSize() * sizeof(int32_t) : 0; + + DeviceMem out_index_dev(indicesSizeInBytes); + + InElementwiseOperation in_elementwise_op; + OutElementwiseOperation out_elementwise_op; + + std::array arrInLengths; + std::array arrInStrides; + + std::array arrOutLengths; + std::array arrOutStrides; + + ck::ranges::copy(inLengths, arrInLengths.begin()); + ck::ranges::copy(inStrides, arrInStrides.begin()); + + ck::ranges::copy(outLengths, arrOutLengths.begin()); + ck::ranges::copy(outStrides, arrOutStrides.begin()); + + if(do_verification) + { + using ReferenceReduceInstance = + ck::tensor_operation::host::ReferenceReduce; + + auto reduce_ref = ReferenceReduceInstance{}; + + auto argument_ptr_ref = reduce_ref.MakeArgumentPointer(arrInLengths, + arrInStrides, + arrOutLengths, + arrOutStrides, + reduceDims, + static_cast(alpha), + static_cast(beta), + in.mData.data(), + nullptr, + out_ref.mData.data(), + out_indices_ref.mData.data(), + in_elementwise_op, + PassThrough{}); + + if(!reduce_ref.IsSupportedArgument(argument_ptr_ref.get())) + { + std::cout << "The runtime parameters not supported by the reduce reference, exiting!" + << std::endl; + return (false); + }; + + auto invoker_ptr_ref = reduce_ref.MakeInvokerPointer(); + + invoker_ptr_ref->Run(argument_ptr_ref.get()); + + for(std::size_t i = 0; i < out_ref.GetElementSize(); i++) + out_elementwise_op(out_ref.mData[i], out_ref.mData[i], d0.mData[i]); + }; + + auto reduce = DeviceReduceInstance{}; + + auto argument_ptr = reduce.MakeArgumentPointer(arrInLengths, + arrInStrides, + {arrOutLengths}, + {arrOutStrides}, + arrOutLengths, + arrOutStrides, + reduceDims, + in_dev.GetDeviceBuffer(), + {d0_dev.GetDeviceBuffer()}, + out_dev.GetDeviceBuffer(), + in_elementwise_op, + out_elementwise_op); + + if(!reduce.IsSupportedArgument(argument_ptr.get())) + { + std::cerr << "The runtime parameters not supported by the DeviceReduce instance, exiting!" + << std::endl; + + return (-2); + }; + + std::string reduce_name = reduce.GetTypeString(); + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) + + invariant_total_length * sizeof(InOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name + << std::endl; + + bool pass = true; + + if(do_verification) + { + + out_dev.FromDevice(out.mData.data()); + + pass = pass && ck::utils::check_err(out, out_ref); + + if(OutputIndex) + { + out_index_dev.FromDevice(out_indices.mData.data()); + pass = pass && ck::utils::check_err(out_indices, out_indices_ref); + }; + }; + + return (pass ? 
0 : 1); +} diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt index d0f356757b3845cb1a72708f0cd5c367eb1042fc..e2a923cded8fe2ceecabfaff5a3201817ae3700a 100644 --- a/example/13_pool2d_fwd/CMakeLists.txt +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -1,6 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) -endif() +add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) +add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 3b3ad80dd826fef38f8d4ed36d504538ae3e263a..8703fa3ed7bb811be45b17378e7aa4a4837502e7 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,17 +1,3 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) -# dlops -if(DL_KERNELS) - add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) -endif() - -# xdlops -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) - add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) - set(target 1) - endif() -endforeach() -endif() \ No newline at end of file +add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) +add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) +add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index dca60b0e739d556ee9c6cb6d463884c8b308866a..20cbc5fdca60cc9ab2e9e0ea22bd665b69871df7 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1,26 +1,35 @@ add_custom_target(example_grouped_gemm_xdl) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) - add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) - add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) - add_dependencies(example_grouped_gemm_xdl - example_grouped_gemm_xdl_fp16 - example_grouped_gemm_multiple_d_dl_fp16 - example_grouped_gemm_xdl_splitk_fp16) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) -endif() +add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) +add_example_dependencies(example_grouped_gemm_xdl 
example_grouped_gemm_xdl_fp32) + +add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16) + +add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16) + +add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16) + +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16) + +add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16) + +add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16) + +add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) + +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8) + +add_example_executable(example_grouped_gemm_multiple_d_xdl_fp16 grouped_gemm_multiple_d_xdl_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_xdl_fp16) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) + add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) + add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() diff --git a/example/15_grouped_gemm/README.md b/example/15_grouped_gemm/README.md index c83b23e08cc7b923ce22316e0c882886d7437ef3..a2afe0f4b9ed003de276d6be71f9efab64525393 100644 --- a/example/15_grouped_gemm/README.md +++ b/example/15_grouped_gemm/README.md @@ -7,19 +7,3 @@ #arg3: run kernel # of times (>1) ./bin/example_grouped_gemm_xdl_fp16 0 1 5 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1} -gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1} -gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1} -gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1} -group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128} -group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256} -group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384} -group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, 
arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512} -launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> -``` diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ecff7b4713f1ed5ff6402e9ab8ae7198f900e1fe --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp @@ -0,0 +1,394 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr int NumDMatrices = 2; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| 
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector> stride_Ds; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 128; + bool time_kernel = true; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + std::vector p_As; + std::vector p_Bs; + std::vector> p_Ds = {}; + + gemm_descs.reserve(group_count); + p_As.reserve(group_count); + p_Bs.reserve(group_count); + p_Ds.reserve(group_count); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector, NumDMatrices>> d_tensors; + std::vector> c_host_tensors; + std::vector> c_device_result_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_result_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + std::vector> d_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + + auto d0_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + auto d1_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + + std::array, NumDMatrices> d_tens = {d0_tensor, d1_tensor}; + d_tensors.push_back(d_tens); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc + << " c_m_n: " << 
c_device_result_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + + sizeof(BDataType) * b_tensors[i].GetElementSize() + + sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDMatrices + + sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + } + } + } + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); + + b_tensors_device.emplace_back( + std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); + + c_tensors_device.emplace_back(std::make_unique( + c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); + + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors_device[i].emplace_back(std::make_unique( + d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); + } + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); + } + c_tensors_device[i]->SetZero(); + p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_Ds.push_back( + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + gemm_descs.push_back({problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + problem_size.stride_Ds[i]}); + } + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + gemm.SetKBatchSize(argument, config.k_batch); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + + invoker.Run(argument, StreamConfig{nullptr, false, 1}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultipleD; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + auto karg = argument.gemm_kernel_args_[i].karg_; + auto dev_res_tensor = + Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideC, ELayout{})); + + c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(), + c_device_result_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + d_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; + } + + return pass; +} + +std::vector argToIntArray(char* input) +{ + std::vector out; + + std::istringstream in(input); + + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + + return out; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc < 11) + { + std::vector Ms{64, 127, 255, 129, 260, 190, 77}; + problem_size.group_count = Ms.size(); + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(Ms[i]); + problem_size.Ns.push_back(252); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + problem_size.stride_Ds.push_back({}); + for(int j = 0; j < NumDMatrices; ++j) + { + problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); + } + } + + std::cout + << "Usage:\n" + << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "arg10: k_batch (> 0)\n" + << "... setting default values." 
<< std::endl; + } + else + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[10]); + + problem_size.Ms = argToIntArray(argv[4]); + problem_size.Ns = argToIntArray(argv[5]); + problem_size.Ks = argToIntArray(argv[6]); + + problem_size.stride_As = argToIntArray(argv[7]); + problem_size.stride_Bs = argToIntArray(argv[8]); + problem_size.stride_Cs = argToIntArray(argv[9]); + + for(int j = 0; j < NumDMatrices; ++j) + { + problem_size.stride_Ds.push_back(problem_size.stride_Cs); + } + + problem_size.group_count = problem_size.Ms.size(); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..965a0e7e37836c06e3aceb29cd19976444c99707 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -0,0 +1,404 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr int NumDs = 2; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmMultipleDXdlCShuffleTileLoop + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| 
Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector> stride_Ds; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + using GemmDesc = ck::tensor_operation::device::GemmDesc; + + // GEMM shape + std::vector gemm_descs; + std::vector ggemm_kargs; + std::vector p_Cs; + std::vector p_As; + std::vector p_Bs; + std::vector> p_Ds = {}; + + gemm_descs.reserve(group_count); + ggemm_kargs.reserve(group_count); + p_As.reserve(group_count); + p_Bs.reserve(group_count); + p_Ds.reserve(group_count); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector, NumDs>> d_tensors; + std::vector> c_host_tensors; + std::vector> c_device_result_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_result_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + std::vector> d_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + + auto d0_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + auto d1_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], 
problem_size.stride_Cs[i], DLayout{})); + + std::array, NumDs> d_tens = {d0_tensor, d1_tensor}; + d_tensors.push_back(d_tens); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc + << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + + sizeof(BDataType) * b_tensors[i].GetElementSize() + + sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs + + sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + for(int j = 0; j < NumDs; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + } + } + } + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); + b_tensors_device.emplace_back( + std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); + c_tensors_device.emplace_back(std::make_unique( + c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); + + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i].emplace_back(std::make_unique( + d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); + } + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + for(int j = 0; j < NumDs; ++j) + { + d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); + } + c_tensors_device[i]->SetZero(); + + p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_Ds.push_back( + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + // The device op does not have to know M problem size at lunch time. 
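+ // M is therefore set to 0 in the GemmDesc entries below; the actual per-group M is supplied
+ // through the GroupedGemmTileLoopKernelArguments (ggemm_kargs) that are copied to device
+ // memory before the kernel is launched.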
+ gemm_descs.push_back({0, + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}}); + ggemm_kargs.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {problem_size.stride_Cs[i], problem_size.stride_Cs[i]}, + problem_size.stride_Cs[i]}); + } + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + ggemm_kargs.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + + invoker.Run(argument, StreamConfig{nullptr, false, 1}); + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultipleD; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + auto karg = ggemm_kargs[i]; + auto dev_res_tensor = + Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{})); + c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data()); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + d_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl; + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; +} + +std::vector argToIntArray(char* input) +{ + std::vector out; + std::istringstream in(input); + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + return out; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc < 10) + { + std::vector Ms{64, 127, 255, 129, 260, 190, 77}; + problem_size.group_count = Ms.size(); + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(Ms[i]); + problem_size.Ns.push_back(252); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + problem_size.stride_Ds.push_back({}); + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); + } + } + + std::cout + << "Usage:\n" + << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "... setting default values." << std::endl; + } + else + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.Ms = argToIntArray(argv[4]); + problem_size.Ns = argToIntArray(argv[5]); + problem_size.Ks = argToIntArray(argv[6]); + + problem_size.stride_As = argToIntArray(argv[7]); + problem_size.stride_Bs = argToIntArray(argv[8]); + problem_size.stride_Cs = argToIntArray(argv[9]); + + for(int j = 0; j < NumDs; ++j) + { + problem_size.stride_Ds.push_back(problem_size.stride_Cs); + } + + problem_size.group_count = problem_size.Ms.size(); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp similarity index 100% rename from example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp rename to example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a193fc39ba637cbd41df4743e0157ca40c38407b --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -0,0 +1,353 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
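+//
+// This example runs a grouped GEMM through the DeviceGroupedGemm_Xdl_Fixed_NK device op and adds
+// a bias row (the D0 tensor, broadcast along M via a zero row stride) to each GEMM output through
+// the Add element-wise operation.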
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; + +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool 
do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, d0_tensors_device, + c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + 
b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {0}}); + + grouped_gemm_kernel_args_.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass 
&= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ns.push_back(768); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ns[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1a2bcfb33edd5d87d692a5aedd99ae1b0edfac2a --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| 
CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + 
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(128 + rand() % 128); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(1024); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a63a29843aa319a38b3673db702952e06d8d851 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -0,0 +1,328 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using AccDataType = F32; +using CShuffleDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel 
= false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + 
c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(256); + problem_size.Ks.push_back(128); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer 
value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp index 743ab96be6291e617cb012957865a6897bd61a91..9f8f6cb1e49d7a84f8a1bb596e568131874cf63f 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -66,13 +66,11 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - problem_size.Ms = { - 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ns.push_back(768); - problem_size.Ks.push_back(4608); + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 00786d34a3af79a6d896e58b4fa12d9bbf8a7aa3..1e12c16f30fbd60ede27a296ff5f7517ca61a478 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,47 +1,41 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_reduce_xdl) - add_custom_target(example_gemm_reduce_xdl_max) - add_custom_target(example_gemm_reduce_xdl_mean_meansquare) - add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) - add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) - 
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) - endif() - - add_dependencies(example_gemm_reduce_xdl - example_gemm_reduce_xdl_mean_meansquare - example_gemm_reduce_xdl_max - example_gemm_add_add_mean_meansquare_xdl) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) - endif() - set(target 1) - endif() -endforeach() +add_custom_target(example_gemm_reduce_xdl) +add_custom_target(example_gemm_reduce_xdl_max) +add_custom_target(example_gemm_reduce_xdl_mean_meansquare) +add_custom_target(example_gemm_add_add_mean_meansquare_xdl) + +add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) + +add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) + +add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) + +add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) + +add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) + +add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) + +add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) + +add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) + +add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) + +add_example_dependencies(example_gemm_reduce_xdl + example_gemm_reduce_xdl_mean_meansquare + example_gemm_reduce_xdl_max + example_gemm_add_add_mean_meansquare_xdl) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) +endif() diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index 2f6533d4481a2f339417a19f0f290f7fd306ee9c..a46eaa48169e534907098692e547dbc5bbebab11 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -198,7 +198,7 @@ int main() throw std::runtime_error("wrong! 
this device_op instance does not support this problem"); } - // init reducetion buffer to 0 + // init reduction buffer to 0 r0_device_buf.SetZero(); r1_device_buf.SetZero(); diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index e187bd4337d53e9c1e7b2ecedf023169386825e3..39f9fb8ec0696b1d3d7ec1f09f4eeb44d8b5accd 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,15 +1,9 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) - target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) - set(target 1) - endif() -endforeach() - if(DL_KERNELS) - add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) - target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) - endif() +add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) +endif() + +add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) endif() diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp index 4a9d16c5c303e43989c1f32e51c2cbce6f279e5d..d219df024533a6277f192892a895db1e8b05fd11 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification, // reset input to zero in_device_buf.SetZero(); + std::vector input_spatial_lengths_i32(NDimSpatial); + std::vector filter_spatial_lengths_i32(NDimSpatial); + std::vector output_spatial_lengths_i32(NDimSpatial); + std::vector conv_filter_strides_i32(NDimSpatial); + std::vector conv_filter_dilations_i32(NDimSpatial); + std::vector input_left_pads_i32(NDimSpatial); + std::vector input_right_pads_i32(NDimSpatial); + + for(ck::index_t d = 0; d < NDimSpatial; d++) + { + input_spatial_lengths_i32[d] = + static_cast(conv_param.input_spatial_lengths_[d]); + filter_spatial_lengths_i32[d] = + static_cast(conv_param.filter_spatial_lengths_[d]); + output_spatial_lengths_i32[d] = + static_cast(conv_param.GetOutputSpatialLengths()[d]); + conv_filter_strides_i32[d] = static_cast(conv_param.conv_filter_strides_[d]); + conv_filter_dilations_i32[d] = + static_cast(conv_param.conv_filter_dilations_[d]); + input_left_pads_i32[d] = static_cast(conv_param.input_left_pads_[d]); + input_right_pads_i32[d] = static_cast(conv_param.input_right_pads_[d]); + } + // do GEMM auto conv = DeviceConvNdBwdDataInstance{}; auto invoker = conv.MakeInvoker(); @@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification, conv.MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.N_, - conv_param.K_, - conv_param.C_, - conv_param.input_spatial_lengths_, - conv_param.filter_spatial_lengths_, - conv_param.GetOutputSpatialLengths(), - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, + static_cast(conv_param.N_), + static_cast(conv_param.K_), + static_cast(conv_param.C_), + input_spatial_lengths_i32, + filter_spatial_lengths_i32, + output_spatial_lengths_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, in_element_op, wei_element_op, out_element_op); diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index a1bb398af0640001f7fba2192e47ffc03e5ce3a0..94ed129dc03664043b1a0295f9f1dcfb38a77002 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index e363dc5c12dd84a36a67e969b06d29a179679f94..62295c57eb3040cbf16547ec277f49a5e44e9a44 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -272,15 +272,14 @@ int main(int argc, char* argv[]) { for(int m = 0; m < M; ++m) { - auto reduce0_acc = reduce0_op.GetIdentityValue(); - auto reduce1_acc = reduce1_op.GetIdentityValue(); - + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + ReduceAccDataType d0_val = 0; + ReduceAccDataType d1_val = 0; for(int n = 0; n < N; ++n) { auto c_val = ck::type_convert(c_g_m_n_host_result(batch, m, n)); - ReduceAccDataType d0_val; - ReduceAccDataType d1_val; 
UnaryIdenticElementOp{}(d0_val, c_val); UnarySquareElementOp{}(d1_val, c_val); diff --git a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp index 24c8d82f674d97afae8cb1b3dc0274c1c13daf35..1e7899b35df95c7977af27f4286122efa256ace6 100644 --- a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp +++ b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -27,7 +27,12 @@ using DeviceElementwiseAddInstance = ck::Tuple, Add, 2, + 64, + 64, + 64, 8, + 8, + ck::Sequence<1, 0>, ck::Sequence<8, 8>, ck::Sequence<8>>; diff --git a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp index 3c04c561403d0ecfb0819aa569bdf24511053442..5f321e62840280d4dcf12ecf010327b38c6f12c3 100644 --- a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp +++ b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -27,9 +27,14 @@ using DeviceElementwiseAddInstance = ck::Tuple, Add, 3, - 8, - ck::Sequence<1, 8>, - ck::Sequence<8>>; + 64, + 16, + 16, + 2, + 2, + ck::Sequence<1, 0>, + ck::Sequence<1, 2>, + ck::Sequence<2>>; template void host_broadcast3D_am_bmnk(HostTensorC& C, diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp index 1ac09641a1e20a6be8bf44f4f1ff0f3be045dedb..90e2d28d956df6b516a54630983f3a77b5ca96ae 100644 --- a/example/19_binary_elementwise/elementwise_add_1d.cpp +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -1,11 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -25,9 +25,14 @@ using DeviceElementwiseAddInstance = ck::Tuple, Add, 1, - 8, - ck::Sequence<8, 8>, - ck::Sequence<8>>; + 64, + 16, + 16, + 2, + 2, + ck::Sequence<1, 0>, + ck::Sequence<2, 2>, + ck::Sequence<2>>; template void host_elementwise1D( diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp index e571aa8468008a5f4e7097eee74763a583faf9af..797521dcb4a577cc2511d9798dac030640eafda5 100644 --- a/example/19_binary_elementwise/elementwise_add_4d.cpp +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -27,9 +27,14 @@ using DeviceElementwiseAddInstance = ck::Tuple, Add, 4, - 8, - ck::Sequence<8, 8>, - ck::Sequence<8>>; + 64, + 2, + 128, + 2, + 2, + ck::Sequence<1, 0>, + ck::Sequence<2, 2>, + ck::Sequence<2>>; template void host_elementwise4D(HostTensorC& C, diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index d649567ed2c46ce05073d731a435fbf65c034b73..497ea19e11f6cd374f50587a619819d9fcd4e5e3 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,24 +1,15 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) - endif() - set(target 1) - endif() -endforeach() +add_custom_target(example_grouped_conv_bwd_weight) +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - if(DL_KERNELS) - add_custom_target(example_grouped_conv_bwd_weight_dl) - add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) - add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) - endif() -endif() \ No newline at end of file 
+add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) + +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) + +add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) + +add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16) diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index 15727495f0f1f8dfedcb24a0ffa02d3d7aea67e0..e0034bf7ebd4bca22cb429885e1e38a4addd3397 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -23,6 +23,8 @@ using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using F8 = ck::f8_t; +using BF8 = ck::bf8_t; template using S = ck::Sequence; @@ -40,25 +42,21 @@ struct CommonLayoutSetting using OutputLayout = OutputLay; }; -template -struct CommonLayoutSettingSelector; - namespace ctl = ck::tensor_layout::convolution; - -template <> -struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting -{ -}; - -template <> -struct CommonLayoutSettingSelector<2> final - : CommonLayoutSetting -{ -}; - -template <> -struct CommonLayoutSettingSelector<3> final - : CommonLayoutSetting +template +struct CommonLayoutSettingSelector + : CommonLayoutSetting>, + ck::tuple_element_t>, + ck::tuple_element_t>> { }; @@ -78,10 +76,10 @@ struct ExecutionConfig final bool time_kernel = false; }; -#define DefaultConvParam \ - ck::utils::conv::ConvParam \ - { \ - 2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \ +#define DefaultConvParam \ + ck::utils::conv::ConvParam \ + { \ + 3, 4, 1, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, { 1, 1, 1 } \ } inline void print_help_msg() diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp index 375c309e1c93d30a6d7167d524613c9386510551..44ba3d8e02498ce72c6f0f4f8e9bdcaaf9ef1760 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp @@ -3,7 +3,7 @@ #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp" using InDataType = F16; using WeiDataType = F16; @@ -15,45 +15,84 @@ using WeiElementOp = PassThrough; using OutElementOp = PassThrough; template -using DeviceConvBwdWeightInstance = - ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl< - NDimSpatial, // NDimSpatial - InDataType, // InDataType - WeiDataType, // WeiDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - InElementOp, // InElementwiseOperation - WeiElementOp, // WeiElementwiseOperation - OutElementOp, // OutElementwiseOperation - ConvBwdWeightDefault, // ConvBackwardWeightSpecialization - 256, // BlockSize - 
128, // MPerBlock - 128, // NPerBlock - 16, // K0PerBlock - 2, // K1 - 4, // M1PerThread - 4, // N1PerThread - 1, // KPerThread - S<8, 2>, // M1N1ThreadClusterM1Xs - S<8, 2>, // M1N1ThreadClusterN1Xs - S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 - S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 - S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder - S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder - S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 - S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder - S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 - S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 - S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 - S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder - S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder - S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 - S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder - S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 - S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - 4>; // CThreadTransferDstScalarPerVector +using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl< + NDimSpatial, // NDimSpatial + ck::tuple_element_t>, // InLayout + ck::tuple_element_t>, // WeiLayout + ck::tuple_element_t>, // OutLayout + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 2, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder + S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder + S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder + S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { 
+ return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c2950b17e802b9418c1614eb8c8c34168929c3b --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + NDimSpatial, + ck::tensor_layout::convolution::GNDHWC, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::GNDHWK, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 16, // MPerWMMA + 16, // NPerWMMA + 4, // MRepeat + 2, // NRepeat + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // ABlockTransferSrcAccessOrder + 1, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 4, + 2, + S<1, 32, 1, 8>, + 1>; + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index 31e277b5c709a44a3495348c2e9419dafeb7611c..80b97249304f1cb8403246785335d4c5c98784b1 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -67,6 +67,34 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // 
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp index 69c831cc54eef84fc0b7ea46f9961573bf3efff4..5c1c97d615df22a41a8331c800e499a1b433b047 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp @@ -66,6 +66,34 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fca6cdd7ab4248d25cfd46fe8d794a2e18e94a9a --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
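The new example that starts here stores the convolution tensors in fp16 but drives the XDL math with 8-bit compute types, BF8 on the A side and F8 on the B side. The aliases it relies on are expected to come from the shared common.hpp, which is not part of this diff; the sketch below is only a guess at what they resolve to, for orientation while reading the instance definition.

```cpp
// Assumed aliases (the authoritative definitions live in the example's common.hpp):
#include "ck/utility/data_type.hpp" // header assumed to provide ck::half_t, ck::f8_t, ck::bf8_t

using F16 = ck::half_t; // 16-bit tensor data (In/Wei/Out)
using F32 = float;      // accumulator type
using BF8 = ck::bf8_t;  // 8-bit float used as ComputeTypeA
using F8  = ck::f8_t;   // 8-bit float used as ComputeTypeB
```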
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; +using ComputeTypeA = BF8; +using ComputeTypeB = F8; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 2, // CBlockTransferScalarPerVector_NWaveNPerXdl + ComputeTypeA, // ComputeTypeA + ComputeTypeB>; // ComputeTypeB + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 29ce0324abecbbfbc53864871fa6cf558339bbc2..f320c0305bf29dad3a6ec478057de017216c5039 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -1,33 +1,12 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
-template -using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; - template bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_param) { - ck::index_t split_k; - // Set split_k = 2 for xdl op, split_k = 1 for dl - // Dl op doesn't support split_k > 1 - // TODO: Add Dl op split_k > 1 support - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102")) - { - split_k = 2; - } - else - { - split_k = 1; - } + // Dl and WMMA ops don't support split_k > 1 + constexpr ck::index_t split_k = 1; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< @@ -58,8 +37,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - out.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in.GenerateTensorValue(GeneratorTensor_3{0.0, 0.2}); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); } DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); @@ -125,18 +104,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return true; } - float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = conv_param.GetFlops(); - std::size_t num_btype = conv_param.GetByte(); - - float tflops = static_cast(flop) / 1.E9 / avg_time; - - float gb_per_sec = num_btype / 1.E6 / avg_time; - - std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl - << "DeviceOp: " << conv.GetTypeString() << std::endl; + invoker.Run(argument, StreamConfig{nullptr, false}); if(config.do_verification) { @@ -151,7 +119,10 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, conv_param.input_right_pads_, InElementOp{}, WeiElementOp{}, - OutElementOp{}); + OutElementOp{}, + {}, + {}, + {}); ref_invoker.Run(ref_argument); @@ -160,25 +131,18 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); } - return true; -} + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); -bool run_grouped_conv_bwd_weight_example(int argc, char* argv[]) -{ - ExecutionConfig config; - ck::utils::conv::ConvParam conv_param = DefaultConvParam; + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); - if(!parse_cmd_args(argc, argv, config, conv_param)) - { - return false; - } + float tflops = static_cast(flop) / 1.E9 / avg_time; - switch(conv_param.num_dim_spatial_) - { - case 1: return run_grouped_conv_bwd_weight<1>(config, conv_param); - case 2: return run_grouped_conv_bwd_weight<2>(config, conv_param); - case 3: return run_grouped_conv_bwd_weight<3>(config, conv_param); - } + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl + << "DeviceOp: " << conv.GetTypeString() << std::endl; - return false; + return true; } diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index 6a6735efd626bc1931fa9f7e6651dceae713da11..2eb7052e1eea21ea652f088813618b15256fca10 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ 
b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,13 +1,4 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) - set(target 1) - endif() -endforeach() -endif() +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp index 96d04dcb37798fb6db3334bd2d802e435fdc93fd..5dccb11bbae7db2c72abb1d4719e6f248c70308b 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -9,7 +9,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/device_memory.hpp" @@ -103,9 +103,14 @@ using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwiseI ck::Tuple, // y NormalizeFunctor, 2, - 8, // MPerthread - ck::Sequence<8, 1, 1, 8, 8>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta - ck::Sequence<8>>; // scalarPerVector: y(layerNorm_out) + 64, // BlockSize + 16, // MPerBlock + 16, // NPerBlock + 2, // MPerthread + 2, // NPerthread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<2, 1, 1, 2, 2>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta + ck::Sequence<2>>; // scalarPerVector: y(layerNorm_out) auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { return HostTensorDescriptor({len}, {stride}); diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index fc58ca19f8673aa0b2205df21350f7ddf0b92196..6a92e9a2f50c8a5f1b80228307ad946188ea3c75 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor& h_m_n, BetaDataType, HDataType, AccDataType, + AccDataType, HElementOp, 2, 1>; Tensor e_m_n(HostTensorDescriptor{M, N}); Tensor c_m_n(HostTensorDescriptor{M, N}); + Tensor save_mean({M}); + Tensor save_inv_std({M}); auto ref_gemm = ReferenceGemm{}; auto ref_gemm_invoker = ref_gemm.MakeInvoker(); @@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor& h_m_n, auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); auto ref_layernorm_argument = ref_layernorm.MakeArgument( - e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon); + e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon); ref_layernorm_invoker.Run(ref_layernorm_argument); } diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp index bd1d6932aceb72890dcb0f7c35004cc0dccd93af..168193ad5b599844c725ce2f273bebe468519597 100644 --- a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
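The naive GEMM + layer-norm examples above and below hand the GEMM output x, together with the per-row mean and mean-of-squares reductions, to the element-wise NormalizeFunctor. The functor itself is defined in the example sources outside this diff; assuming the usual two-pass formulation, its per-element math would look like the following host-side sketch.

```cpp
#include <cmath>

// Sketch only (assumption): normalize one GEMM output element given the row
// statistics produced by the multiple-R GEMM and the affine parameters.
inline float normalize_ref(
    float x, float mean, float mean_square, float gamma, float beta, float eps = 1e-4f)
{
    const float variance = mean_square - mean * mean;             // E[x^2] - E[x]^2
    return (x - mean) / std::sqrt(variance + eps) * gamma + beta; // eps value is illustrative
}
```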
#include #include @@ -9,7 +9,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/device_memory.hpp" @@ -102,9 +102,14 @@ using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwiseI ck::Tuple, // y NormalizeFunctor, 2, - 8, // MPerthread - ck::Sequence<8, 1, 1, 8, 8>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta - ck::Sequence<8>>; // scalarPerVector: y(layerNorm_out) + 64, // BlockSize + 16, // MPerBlock + 16, // NPerBlock + 2, // MPerthread + 2, // NPerthread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<2, 1, 1, 2, 2>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta + ck::Sequence<2>>; // scalarPerVector: y(layerNorm_out) auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { return HostTensorDescriptor({len}, {stride}); diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt index 854f07fda6aad404fda546b40bef1c8b45c766d5..44585b11d04e762f58cc20c1816200857eef0fa6 100644 --- a/example/22_cgemm/CMakeLists.txt +++ b/example/22_cgemm/CMakeLists.txt @@ -1,22 +1,18 @@ add_custom_target(example_cgemm_xdl) -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) +add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) + +add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) + add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) -add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) -endif() +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) + +add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) + add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) + add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) endif() diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 48a3b58ff53a712492e82ffce5632c08b6937ba2..720af39af645e622cba1897a46fb1f7004516dae 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -1,21 +1,24 @@ add_custom_target(example_batched_gemm_xdl) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - 
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bfp16) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) -endif() + +add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) + +add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) + +add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16) + +add_example_executable(example_batched_gemm_xdl_bf16_v3 batched_gemm_xdl_bf16_v3.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16_v3) + +add_example_executable(example_batched_gemm_xdl_fp8_rowwise_v3 batched_gemm_xdl_fp8_rowwise_v3.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp8_rowwise_v3) + +add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) + add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) + add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) endif() diff --git a/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16.cpp similarity index 100% rename from example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp rename to example/24_batched_gemm/batched_gemm_xdl_bf16.cpp diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa8b752185c112160434d2e1b9866312dbbbda14 --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmDefault, + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + S<8>, // CDEShuffleBlockTransferScalarPerVectors + ck::BlockGemmPipelineScheduler::Intrawave, // BlockGemmPipelineScheduler + ck::BlockGemmPipelineVersion::v3 // BlockGemmPipelineVersion + >; + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f0160b31ce1e6a3444b8bd12d47a96a1cf7d9698 --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
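This second new batched-GEMM example combines FP8 A/B tensors with two FP32 auxiliary tensors through a MultiplyMultiply epilogue. With the D strides chosen by the row-wise runner added later in this diff (stride_D0 = stride_D1 = 0, batch_stride_D0 = N, batch_stride_D1 = M), d0 acts as a per-column scale and d1 as a per-row scale. A minimal host-side sketch of what the epilogue is assumed to compute, mirroring the verification loop in run_batched_gemm_example_rowwise.inc:

```cpp
// Sketch (assumption): e(b, m, n) = c(b, m, n) * d0(n) * d1(m),
// with the product finally converted to the BF16 output type.
inline float rowwise_scaled_gemm_ref(float acc, float d0_col_scale, float d1_row_scale)
{
    return acc * d0_col_scale * d1_row_scale;
}
```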
+#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply; + +using ADataType = F8; +using BDataType = F8; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmDefault, + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + S<8, 8, 1>, // CDEShuffleBlockTransferScalarPerVectors + ck::BlockGemmPipelineScheduler::Interwave, // BlockGemmPipelineScheduler + ck::BlockGemmPipelineVersion::v1, // BlockGemmPipelineVersion + F8 // ComputeTypeA + >; + +#include "run_batched_gemm_example_rowwise.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_rowwise_example(argc, argv); } diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc index 21934add3162ce9508a82656847b056bea8402c6..741512bf008487cff0058816e9d7ce7063bd0d57 100644 --- a/example/24_batched_gemm/run_batched_gemm_example.inc +++ 
b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -210,17 +210,9 @@ bool run_batched_gemm_example(int argc, char* argv[]) problem_size.M = 256 * (dis(gen) + 1); problem_size.N = 128 * (dis(gen) + 1); - problem_size.K = 64 * (dis(gen) + 2); + problem_size.K = 128 * (dis(gen) + 2); - problem_size.stride_A = problem_size.K; - problem_size.stride_B = problem_size.K; - problem_size.stride_C = problem_size.N; - - problem_size.batch_stride_A = problem_size.M * problem_size.K; - problem_size.batch_stride_B = problem_size.K * problem_size.N; - problem_size.batch_stride_C = problem_size.M * problem_size.N; - - problem_size.batch_count = 16; + problem_size.batch_count = 2; if(argc == 4) { @@ -228,13 +220,37 @@ bool run_batched_gemm_example(int argc, char* argv[]) config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } + else if(argc == 8) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + problem_size.batch_count = std::stoi(argv[7]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("optinal\n"); + printf("arg4-7: M = %d N = %d K = %d Batch = %d\n", + problem_size.M, + problem_size.N, + problem_size.K, + problem_size.batch_count); exit(0); } + problem_size.stride_A = problem_size.K; + problem_size.stride_B = problem_size.K; + problem_size.stride_C = problem_size.N; + + problem_size.batch_stride_A = problem_size.M * problem_size.K; + problem_size.batch_stride_B = problem_size.K * problem_size.N; + problem_size.batch_stride_C = problem_size.M * problem_size.N; + return run_batched_gemm(problem_size, config); } diff --git a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc new file mode 100644 index 0000000000000000000000000000000000000000..778be8ffd781dcac8c7f8d1018a032f3398546db --- /dev/null +++ b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
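Both batched-GEMM runners now take an optional second group of command-line arguments: after the usual verification / initialization / time-kernel flags, M, N, K and the batch count can be passed explicitly (seven arguments in total), while the three-argument form still picks random sizes. Assuming the usual `bin/` layout shown in the example READMEs, an invocation would look like `./bin/example_batched_gemm_xdl_bf16_v3 1 2 1 3840 4096 4096 2`, where the sizes are only illustrative.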
+#include + +#pragma once + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + ck::index_t stride_C = N; + + ck::index_t stride_D0 = 0; + ck::index_t stride_D1 = 0; + + ck::index_t batch_stride_A = M * K; + ck::index_t batch_stride_B = K * N; + ck::index_t batch_stride_C = M * N; + + ck::index_t batch_stride_D0 = N; + ck::index_t batch_stride_D1 = M; + + ck::index_t batch_count = 16; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_batched_gemm_rowwise(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, + N, + K, + stride_A, + stride_B, + stride_C, + stride_D0, + stride_D1, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D0, + batch_stride_D1, + batch_count] = problem_size; + + // GEMM shape + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + Tensor a_g_m_k( + f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{})); + Tensor b_g_k_n( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); + Tensor d0_g_m_n( + f_host_tensor_descriptor(batch_count, M, N, stride_D0, batch_stride_D0, D0Layout{})); + Tensor d1_g_m_n( + f_host_tensor_descriptor(batch_count, M, N, stride_D1, batch_stride_D1, D1Layout{})); + Tensor e_g_m_n_device_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl; + std::cout << "d1_g_m_n: " << d1_g_m_n.mDesc << std::endl; + std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + d0_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_g_m_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + d0_device_buf.ToDevice(d0_g_m_n.mData.data()); + d1_device_buf.ToDevice(d1_g_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + 
auto argument = + gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + batch_count, + stride_A, + stride_B, + {stride_D0, stride_D1}, + stride_C, + batch_stride_A, + batch_stride_B, + {batch_stride_D0, batch_stride_D1}, + batch_stride_C, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + invoker.Run(argument, StreamConfig{nullptr, false}); + bool pass = true; + + if(config.do_verification) + { + c_device_buf.FromDevice(e_g_m_n_device_result.mData.data()); + + Tensor c_g_m_n({batch_count, M, N}); + + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + Tensor e_g_m_n_host_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int b = 0; b < batch_count; ++b) + { + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_g_m_n_host_result(b, m, n), + c_g_m_n(b, m, n), + d0_g_m_n(b, m, n), + d1_g_m_n(b, m, n)); + } + } + } + + pass = ck::utils::check_err( + e_g_m_n_device_result, e_g_m_n_host_result, "Error: Incorrect results c"); + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * batch_count * M * N * K; + std::size_t num_btype = sizeof(ADataType) * batch_count * M * K + + sizeof(BDataType) * batch_count * K * N + + sizeof(EDataType) * batch_count * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass ? 
1 : 0;
+}
+
+bool run_batched_gemm_rowwise_example(int argc, char* argv[])
+{
+    ProblemSize problem_size;
+    ExecutionConfig config;
+
+    std::mt19937 gen(11939);
+    std::uniform_int_distribution<> dis(0, 15);
+
+    problem_size.M = 256 * (dis(gen) + 1);
+    problem_size.N = 128 * (dis(gen) + 1);
+    problem_size.K = 128 * (dis(gen) + 2);
+
+    problem_size.batch_count = 2;
+
+    if(argc == 4)
+    {
+        config.do_verification = std::stoi(argv[1]);
+        config.init_method = std::stoi(argv[2]);
+        config.time_kernel = std::stoi(argv[3]);
+    }
+    else if(argc == 8)
+    {
+        config.do_verification = std::stoi(argv[1]);
+        config.init_method = std::stoi(argv[2]);
+        config.time_kernel = std::stoi(argv[3]);
+        problem_size.M = std::stoi(argv[4]);
+        problem_size.N = std::stoi(argv[5]);
+        problem_size.K = std::stoi(argv[6]);
+        problem_size.batch_count = std::stoi(argv[7]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("optional\n");
+        printf("arg4-7: M = %d N = %d K = %d Batch = %d\n",
+               problem_size.M,
+               problem_size.N,
+               problem_size.K,
+               problem_size.batch_count);
+        exit(0);
+    }
+
+    problem_size.stride_A = problem_size.K;
+    problem_size.stride_B = problem_size.K;
+    problem_size.stride_C = problem_size.N;
+
+    problem_size.stride_D0 = 0;
+    problem_size.stride_D1 = 0;
+
+    problem_size.batch_stride_A = problem_size.M * problem_size.K;
+    problem_size.batch_stride_B = problem_size.K * problem_size.N;
+    problem_size.batch_stride_C = problem_size.M * problem_size.N;
+
+    problem_size.batch_stride_D0 = problem_size.N;
+    problem_size.batch_stride_D1 = problem_size.M;
+
+    return run_batched_gemm_rowwise(problem_size, config);
+}
diff --git a/example/25_gemm_bias_e_permute/CMakeLists.txt b/example/25_gemm_bias_e_permute/CMakeLists.txt
index eb274b233802401628abe0f030159da75b8f5950..cbc3c007bc22622554fd93f4d8a829c9ab666dc4 100644
--- a/example/25_gemm_bias_e_permute/CMakeLists.txt
+++ b/example/25_gemm_bias_e_permute/CMakeLists.txt
@@ -1,4 +1,2 @@
-if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-    add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
-    add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
-endif()
+add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
+add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt
index 6cab88b13ef5761a16cedcd9d86d7f855fe5b99e..f3d30cea2a13dbe0d37a51833442d3c628d71321 100644
--- a/example/26_contraction/CMakeLists.txt
+++ b/example/26_contraction/CMakeLists.txt
@@ -1,8 +1,52 @@
-if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
-    add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
-    add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
-endif()
-if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
-    add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
-    add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
-endif()
+add_custom_target(example_contraction)
+add_custom_target(example_contraction_scale)
+add_custom_target(example_contraction_bilinear)
+
+# FP32
+add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) + +add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) + +add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) + +add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) + +add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) + +add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) + +# FP64 +add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) + +add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) + +add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) + +add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) + +# FP16 +add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) + +add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) + +# BF16 +add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) + +add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) + +add_example_dependencies(example_contraction example_contraction_scale) +add_example_dependencies(example_contraction example_contraction_bilinear) diff --git a/example/26_contraction/README.md b/example/26_contraction/README.md index c88d93cf83a411e2499cd0be80e524c42db339c6..acbfa84df108289b374c15ab6843a8cfdd99438f 100644 --- a/example/26_contraction/README.md +++ b/example/26_contraction/README.md @@ -7,14 +7,3 @@ #arg3: time kernel (0=no, 1=yes) ./bin/example_contraction_bilinear_xdl_fp32 1 1 1 ``` - -Result 
(MI100 @ dynammic freq, 46TFlops peak FP32) -``` -a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} -b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096} -c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} -launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4> -``` diff --git a/example/26_contraction/common_instances.hpp b/example/26_contraction/common_instances.hpp new file mode 100644 index 0000000000000000000000000000000000000000..480ca5a0af67bcf8db871fcd79e795497901c5ab --- /dev/null +++ b/example/26_contraction/common_instances.hpp @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using F64 = double; + +template +using S = ck::Sequence; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Generic instances for fp32, fp16 and bf16 data types. +template +// clang-format off +using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| 
Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// 
clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +// Fp64 instances. 
+template +// clang-format off +using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, 
AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + 
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on diff --git a/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77ec5669a76e093fd351f82b64d77ee40a147ad4 --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DDataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..25c4d9eac467f694a61b3f697dd6fd99848353b4 --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
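The shared DeviceOpInstance*_Generic and DeviceOpInstance*_FP64 aliases above fix all of the block/tile tuning constants of DeviceContractionMultipleD_Xdl_CShuffle once per layout variant, so each example file only has to supply its data, accumulation and compute types. A minimal, self-contained sketch of that pattern follows; the struct and tuning constants below are placeholders for illustration only, not composable_kernel classes.

// Toy analogue of the common_instances.hpp pattern (placeholder types, not CK code).
namespace toy {
template <typename AData, typename BData, typename AccData, typename ComputeData,
          int BlockSize, int MPerBlock, int NPerBlock, int KPerBlock>
struct DeviceContraction
{
    // stands in for ck::tensor_operation::device::DeviceContractionMultipleD_Xdl_CShuffle
};
} // namespace toy

// Shared header: bake the tuning constants in once per layout variant...
template <typename AData, typename BData, typename AccData, typename ComputeData>
using DeviceOpInstanceKK_Toy =
    toy::DeviceContraction<AData, BData, AccData, ComputeData, 256, 128, 128, 16>;

// ...so an example file reduces to naming the types it cares about:
using DeviceOpInstance = DeviceOpInstanceKK_Toy<double, double, double, float>;

int main()
{
    DeviceOpInstance op{}; // in the real examples this is the device op that gets invoked
    (void)op;
    return 0;
}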
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp index 78522160c85a1024ec510197e5ab1069e6f077cf..d440e7394e5da2794f3a959ea769587cb60afc7f 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -1,29 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -template -using S = ck::Sequence; - -using F32 = float; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#include "common_instances.hpp" using ADataType = F32; using BDataType = F32; @@ -32,6 +13,7 @@ using CShuffleDataType = F32; using DDataType = F32; using DsDataType = ck::Tuple; using EDataType = F32; +using ComputeDataType = F32; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -41,253 +23,64 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// clang-format off -using DeviceOpInstanceKKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - 
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceKNNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceMKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - 
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceMNNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; -// clang-format on +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; using DeviceOpInstance = DeviceOpInstanceKKNN; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, 
K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // D[M0, M1, N0, N1] - std::vector d_ms_ns_lengths{30, 128, 32, 64}; - std::vector d_ms_ns_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float alpha = 1.f; - float beta = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 28) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - d_ms_ns_lengths = {M0, M1, N0, N1}; - d_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; - - alpha = std::stof(argv[26]); - beta = std::stof(argv[27]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); - printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg26 to 27: alpha, beta\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * 
e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - d_device_buf.ToDevice(d_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 1>{d_ms_ns_lengths}, - std::array, 1>{d_ms_ns_strides}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1), - d_ms_ns(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } +#include "run_contraction_bilinear_example.inc" - return 0; -} +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b748949bd53ce0954e5e62440b74c3d1c5aa235 --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
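The verification loop in the main() removed above applies the Bilinear epilogue on the host: for every output element it combines the reference contraction result C with the auxiliary tensor D, presumably as E = alpha * C + beta * D (the functor is constructed with {alpha, beta} and invoked as cde_element_op(e, c, d)). A minimal host-side sketch of that per-element step, using a stand-in functor rather than ck::tensor_operation::element_wise::Bilinear:

#include <cstdio>

// Stand-in for the Bilinear CDE element op -- assumed to compute
// e = alpha * c + beta * d, matching how the removed loop calls it.
struct BilinearSketch
{
    float alpha;
    float beta;
    void operator()(float& e, float c, float d) const { e = alpha * c + beta * d; }
};

int main()
{
    BilinearSketch cde_element_op{1.f, 1.f}; // the example's default alpha, beta
    float e = 0.f;
    const float c = 2.5f; // reference contraction result A.B
    const float d = 0.5f; // auxiliary D tensor element
    cde_element_op(e, c, d);
    std::printf("e = %f\n", e); // 3.0 with the default alpha = beta = 1
    return 0;
}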
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; +using ComputeDataType = BF16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..334cd615d17e746e4112e54aad8a0107c0f67cea --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; +using ComputeDataType = F16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp index 6cceed5bc11ed8eb79568a0397bbf7e1b93fd33d..86ea7e0212d62a5f0e6773bbba0b3bffaad005c7 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp @@ -1,29 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
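The *_compute_bf16 and *_compute_fp16 variants above keep F32 storage and F32 accumulation but request the multiplies in a narrower ComputeDataType, trading some accuracy for throughput. The sketch below only illustrates the numerical effect of such a downcast; it uses a crude truncation-based fp32-to-bf16 round trip, not CK's conversion code, which rounds rather than truncates.

#include <cstdio>
#include <cstdint>
#include <cstring>

// Crude fp32 -> bf16 -> fp32 round trip by dropping the low mantissa bits,
// just to show that a narrower compute type perturbs the products that
// get accumulated in AccDataType.
static float bf16_round_trip(float x)
{
    std::uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u &= 0xFFFF0000u;
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

int main()
{
    const float a = 1.2345678f;
    const float b = 7.6543210f;
    const float full    = a * b;                                   // ComputeDataType = F32
    const float reduced = bf16_round_trip(a) * bf16_round_trip(b); // bf16-like multiply, fp32 accumulate
    std::printf("fp32 compute: %.7f  bf16-like compute: %.7f\n", full, reduced);
    return 0;
}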
-#include -#include -#include -#include - #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -template -using S = ck::Sequence; - -using F64 = double; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#include "common_instances.hpp" using ADataType = F64; using BDataType = F64; @@ -32,6 +13,7 @@ using CShuffleDataType = F64; using DDataType = F64; using DsDataType = ck::Tuple; using EDataType = F64; +using ComputeDataType = F64; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -41,253 +23,64 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// clang-format off -using DeviceOpInstanceKKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceKNNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceMKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceMNNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; -// clang-format on +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; using DeviceOpInstance = DeviceOpInstanceKKNN; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // D[M0, M1, N0, N1] - std::vector d_ms_ns_lengths{30, 128, 32, 64}; - std::vector d_ms_ns_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float alpha = 1.f; - float beta = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 28) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - d_ms_ns_lengths = {M0, M1, N0, N1}; - d_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; - - alpha = 
std::stof(argv[26]); - beta = std::stof(argv[27]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); - printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg26 to 27: alpha, beta\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - d_device_buf.ToDevice(d_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 1>{d_ms_ns_lengths}, - std::array, 1>{d_ms_ns_strides}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; - - float tflops = 
static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1), - d_ms_ns(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } +#include "run_contraction_bilinear_example.inc" - return 0; -} +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0a2a9a10fe59abeac58d2ec3e43af012b51d82b --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F32; +using CShuffleDataType = F64; +using DDataType = F64; +using DsDataType = ck::Tuple; +using EDataType = F64; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32652fc6f3c8856505c3afc3c764fb0d08f23f48 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
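The timing block in the removed main() above derives its throughput from the contraction shape: flop = 2 * M * N * K and bytes moved = sizeof(A)*M*K + sizeof(B)*K*N + sizeof(D)*M*N + sizeof(E)*M*N, where M, N and K are products of the collapsed m-, n- and k-dimensions. Dividing by 1.E9 and 1.E6 yields TFLOPS and GB/s because the kernel time is reported in milliseconds. A small worked example with the default fp64 problem size:

#include <cstddef>
#include <cstdio>

int main()
{
    // Default problem in the fp64 bilinear example: A[30,128,32,64], B[32,64,32,64].
    const std::size_t M = 30 * 128; // product of the m-dimensions
    const std::size_t N = 32 * 64;  // product of the n-dimensions
    const std::size_t K = 32 * 64;  // product of the k-dimensions

    const double ave_time_ms = 1.0; // placeholder; the real value comes from invoker.Run

    const std::size_t flop = std::size_t(2) * M * N * K;
    const std::size_t num_bytes =
        sizeof(double) * M * K + sizeof(double) * K * N + // A and B reads
        sizeof(double) * M * N + sizeof(double) * M * N;  // D read and E write

    // flop / 1e9 / ms == TFLOPS and bytes / 1e6 / ms == GB/s, since 1 ms = 1e-3 s.
    const double tflops     = static_cast<double>(flop) / 1.E9 / ave_time_ms;
    const double gb_per_sec = static_cast<double>(num_bytes) / 1.E6 / ave_time_ms;

    std::printf("Perf: %.3f TFLOPS, %.3f GB/s (at %.1f ms)\n", tflops, gb_per_sec, ave_time_ms);
    return 0;
}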
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4797e24008984f281718ba4e8e6cdea28555714 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp index 1574f5d18fb7f5f193cd4adb1c1a88616d7b1437..fe984b29d625b69ecf5555a3045f5b925752a1cd 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -1,29 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
-#include -#include -#include -#include - #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -template -using S = ck::Sequence; - -using F32 = float; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#include "common_instances.hpp" using ADataType = F32; using BDataType = F32; @@ -31,6 +12,7 @@ using AccDataType = F32; using CShuffleDataType = F32; using DsDataType = ck::Tuple<>; using EDataType = F32; +using ComputeDataType = F32; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -40,237 +22,64 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// clang-format off -using DeviceOpInstanceKKN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceMKN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceMNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; -// clang-format on +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; using DeviceOpInstance = DeviceOpInstanceKKN; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float scale = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 23) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - scale = std::stof(argv[22]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, 
Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg22: scale\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{scale}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 0>{}, - std::array, 0>{}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - Tensor empty_tensor(std::vector{}, std::vector{}); - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 
< e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } +#include "run_contraction_scale_example.inc" - return 0; -} +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83c8fec179d1cbdf4121c9c16a28a024458c0695 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; +using ComputeDataType = BF16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..27e42e82037610de7f9d5bdfbc1f936e56d8733f --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; +using ComputeDataType = F16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp64.cpp b/example/26_contraction/contraction_scale_xdl_fp64.cpp index 3dacc708877d838dd5fb0113b5d560885b0551ae..37374780b6864cf1b23391749f41ffc55236622a 100644 --- a/example/26_contraction/contraction_scale_xdl_fp64.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp64.cpp @@ -1,29 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -template -using S = ck::Sequence; - -using F64 = double; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#include "common_instances.hpp" using ADataType = F64; using BDataType = F64; @@ -31,6 +12,7 @@ using AccDataType = F64; using CShuffleDataType = F64; using DsDataType = ck::Tuple<>; using EDataType = F64; +using ComputeDataType = F64; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -40,237 +22,64 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// clang-format off -using DeviceOpInstanceKKN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | 
Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceMKN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| 
DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceMNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; -// clang-format on +using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; using DeviceOpInstance = DeviceOpInstanceKKN; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector 
b_ns_ks_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float scale = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 23) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - scale = std::stof(argv[22]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg22: scale\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{scale}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 0>{}, - std::array, 0>{}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - 
std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - Tensor empty_tensor(std::vector{}, std::vector{}); - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } +#include "run_contraction_scale_example.inc" - return 0; -} +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5dbeb231542539d989de97c3c64b16426876664d --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F32; +using CShuffleDataType = F64; +using DsDataType = ck::Tuple<>; +using EDataType = F64; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/run_contraction_bilinear_example.inc b/example/26_contraction/run_contraction_bilinear_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..7b9a01e0ae80b5bd5035fe213a3fe065e8a7a408 --- /dev/null +++ b/example/26_contraction/run_contraction_bilinear_example.inc @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +int run_contraction_bilinear_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + 
std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + 
NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/run_contraction_scale_example.inc b/example/26_contraction/run_contraction_scale_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..65ca182da2cf7c5f97688f3a7c6d70b19cec016f --- /dev/null +++ b/example/26_contraction/run_contraction_scale_example.inc @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +int run_contraction_scale_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + 
a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/27_layernorm/CMakeLists.txt b/example/27_layernorm/CMakeLists.txt deleted file mode 100644 index 9cb2cd0766ea8c5aaa9154e87457763cc38372ab..0000000000000000000000000000000000000000 --- a/example/27_layernorm/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) - add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) -endif() diff --git a/example/27_layernorm/common.hpp b/example/27_layernorm/common.hpp deleted file mode 100644 index 62a71713df84351fea902cdcf3275786b60f5a23..0000000000000000000000000000000000000000 --- a/example/27_layernorm/common.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_common_util.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" diff --git a/example/27_layernorm/layernorm_fp16.cpp b/example/27_layernorm/layernorm_fp16.cpp deleted file mode 100644 index bb8b954f0acaf2e9d748104a05561f9cb02fc9a0..0000000000000000000000000000000000000000 --- a/example/27_layernorm/layernorm_fp16.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 2; -constexpr int NumReduceDim = 1; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector -#include "run_layernorm_example.inc" - -int main() { return run_groupnorm_example(); } diff --git a/example/27_layernorm/layernorm_splitk_fp16.cpp b/example/27_layernorm/layernorm_splitk_fp16.cpp deleted file mode 100644 index e0378d028b343a6a70335cb9163370d1586ebf0a..0000000000000000000000000000000000000000 --- a/example/27_layernorm/layernorm_splitk_fp16.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 2; -constexpr int NumReduceDim = 1; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationSplitKImpl; // YScalarPerVector - -#include "run_layernorm_example.inc" - -int main() { return run_groupnorm_example(); } diff --git a/example/27_layernorm/run_layernorm_example.inc b/example/27_layernorm/run_layernorm_example.inc deleted file mode 100644 index 95200b540aa9f9704c8fad785b352bdf779c4094..0000000000000000000000000000000000000000 --- a/example/27_layernorm/run_layernorm_example.inc +++ /dev/null @@ -1,96 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -template -int run_groupnorm_example() -{ - bool time_kernel = false; - - ck::index_t M = 1024; - ck::index_t N = 1024; - ck::index_t Stride = N; - - auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { - return HostTensorDescriptor({len}, {stride}); - }; - - auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { - using namespace ck::literals; - - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - }; - - Tensor x(f_host_tensor_descriptor2d(M, N, Stride)); - Tensor gamma(f_host_tensor_descriptor1d(N, 1)); - Tensor beta(f_host_tensor_descriptor1d(N, 1)); - Tensor y(f_host_tensor_descriptor2d(M, N, Stride)); - - x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); - DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); - DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); - DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); - - x_dev.ToDevice(x.mData.data()); - gamma_dev.ToDevice(gamma.mData.data()); - beta_dev.ToDevice(beta.mData.data()); - - auto device_instance = DeviceInstance{}; - auto argument_ptr = device_instance.MakeArgumentPointer( - {M, N}, - std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, - {0, 1}, - {0, 1}, - std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, - {1}, - 1e-4, - x_dev.GetDeviceBuffer(), - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - y_dev.GetDeviceBuffer(), - nullptr, - nullptr, - PassThrough{}); - - if(!device_instance.IsSupportedArgument(argument_ptr.get())) - { - std::cout << "The runtime parameters are not supported" << std::endl; - return 1; - }; - - size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); - DeviceMem workspace_dev(workspace_sz); - device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); - - auto invoker_ptr = device_instance.MakeInvokerPointer(); - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - bool pass = true; - { - Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride)); - using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; - - ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4); - auto ref_invoker = ref.MakeInvoker(); - ref_invoker.Run(ref_argument); - - y_dev.FromDevice(y.mData.data()); - pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); - } - - return (pass ? 
0 : 1); -} diff --git a/example/27_layernorm2d_fwd/CMakeLists.txt b/example/27_layernorm2d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..639bd9c400e428599d606eb3e43fd3078e8382e2 --- /dev/null +++ b/example/27_layernorm2d_fwd/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_layernorm2d_fwd_fp16 layernorm2d_fwd_fp16.cpp) +add_example_executable(example_layernorm2d_fwd_splitk_fp16 layernorm2d_fwd_splitk_fp16.cpp) diff --git a/example/27_layernorm2d_fwd/common.hpp b/example/27_layernorm2d_fwd/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6c7b99f89f9930cdcb84bda674f96ab3a1237241 --- /dev/null +++ b/example/27_layernorm2d_fwd/common.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" diff --git a/example/27_layernorm2d_fwd/layernorm2d_fwd_fp16.cpp b/example/27_layernorm2d_fwd/layernorm2d_fwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..20db24f56d38b22d20100a03f51f83fd21815125 --- /dev/null +++ b/example/27_layernorm2d_fwd/layernorm2d_fwd_fp16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationFwdImpl; // SaveMeanInvStdScalarPerVector +#include "run_layernorm_example.inc" + +int main() { return run_layernorm2d_fwd_example(); } diff --git a/example/27_layernorm2d_fwd/layernorm2d_fwd_splitk_fp16.cpp b/example/27_layernorm2d_fwd/layernorm2d_fwd_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a57082c79f2e5c56527c0396952ebba2e9c7bd9 --- /dev/null +++ b/example/27_layernorm2d_fwd/layernorm2d_fwd_splitk_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationFwdSplitKImpl< + XDataType, + GammaDataType, + BetaDataType, + ComputeDataType, + YDataType, + SaveMeanInvStdDataType, + PassThrough, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // ClusterM + 32, // ClusterK + 1, // SliceM + 8, // SliceK + 1, // XYVectorDim (0=M, 1=K) + 8, // XScalarPerVector + 1, // GammaVecDim (0=M, 1=K) + 8, // GammaScalarPerVector + 1, // BetaVecDim (0=M, 1=K) + 8, // BetaScalarPerVector + 8, // YScalarPerVector + 1>; // SaveMeanInvStdScalarPerVector + +#include "run_layernorm_example.inc" + +int main() { return run_layernorm2d_fwd_example(); } diff --git a/example/27_layernorm2d_fwd/run_layernorm_example.inc b/example/27_layernorm2d_fwd/run_layernorm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..23608a1eea01bffa20ebc1938ce4e7b4769e26a1 --- /dev/null +++ b/example/27_layernorm2d_fwd/run_layernorm_example.inc @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +int run_layernorm2d_fwd_example() +{ + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 1024; + + Tensor x({M, N}); + Tensor gamma({N}); + Tensor beta({N}); + Tensor y({M, N}); + Tensor save_mean({M}); + Tensor save_inv_std({M}); + + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); +#ifdef SAVE_MEAN_INV_STD + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); +#endif + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + {M, N}, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + {0, 1}, + {0, 1}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + {1}, + 1e-4, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + size_t workspace_sz = 
device_instance.GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_y({M, N}); + Tensor host_save_mean({M}); + Tensor host_save_inv_std({M}); + + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernorm; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + PassThrough{}, + {M, N}, + {1}, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + y_dev.FromDevice(y.mData.data()); + pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3); +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.FromDevice(save_mean.mData.data()); + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err( + save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3); + pass &= ck::utils::check_err( + save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3); +#endif + } + + return (pass ? 0 : 1); +} diff --git a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt index 2fda1f62a9a94c69bde8f08e8c81167b2951c4d6..44ab16894ce054a4b25b2efd12ebac9152b3bd58 100644 --- a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt +++ b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) -endif() +add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index 09c3e6c6085c233aeac87374e6996cade4dac142..ac54aebdc214a5607eafa91d5142443eeb0649f1 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,7 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) - - if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) - endif() -endif() +add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 62233e535151e6b647e938de04b6225fe487b6c4..f556be887f9824c097adbefb5637b985ebf349cd 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -43,9 +43,10 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = 
ck::tensor_operation::element_wise::Add; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; using DeviceOpInstanceKKNN = @@ -55,43 +56,44 @@ using DeviceOpInstanceKKNN = NumDimK, ADataType, BDataType, - DsDataType, - EDataType, AccDataType, CShuffleDataType, + DsDataType, + EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ABSpec, - ABSpec, + ASpec, + BSpec, DESpec, - 256, + 1, 128, - 256, - 8, - 8, + 64, + 64, + 64, + 4, 16, 16, + 1, 4, - 4, - S<4, 64, 1>, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, - 8, - 8, - true, - S<4, 64, 1>, + 4, + 4, + false, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, - 8, - 8, - true, + 4, + 4, + false, 1, 1, - S<1, 32, 1, 8>, + S<1, 64, 1, 2>, 8>; using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -251,21 +253,6 @@ int main(int argc, char* argv[]) ck::index_t K0 = 2048; - // A[G0, G1, M0, M1, K0] - std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; - std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; - // B[G0, G1, N0, N1, K0] - std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; - std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; - - // D[G0, G1, M0, N0, M1, N1] - std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; - // E[G0, G1, M0, N0, M1, N1] - std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; - std::vector e_gs_ms_ns_strides{ - G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - if(argc == 1) { // use default case @@ -276,13 +263,43 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + G0 = std::stoi(argv[4]); + G1 = std::stoi(argv[5]); + M0 = std::stoi(argv[6]); + M1 = std::stoi(argv[7]); + N0 = std::stoi(argv[8]); + N1 = std::stoi(argv[9]); + K0 = std::stoi(argv[10]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n"); exit(0); } + + // A[G0, G1, M0, M1, K0] + std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; + // B[G0, G1, N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; + std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; + + // D[G0, G1, M0, N0, M1, N1] + std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; + // E[G0, G1, M0, N0, M1, N1] + std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector e_gs_ms_ns_strides{ + G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); 
 Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
index e37413c09880de71b5bc10b15624fb98e75bf4ba..7acb1a1907bbe8fb7777e6e1bef4ce3a6200f516 100644
--- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
+++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
@@ -1,46 +1,23 @@
-list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
-list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)
+add_custom_target(example_grouped_conv_fwd_multiple_d)
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
-        add_custom_target(example_grouped_conv_fwd_multiple_d)
-        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-            add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
-            add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
-        endif()
-        if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
-            add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
-        endif()
-        if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
-            add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
-        endif()
-        if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
-            add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
-        endif()
-        if(USE_BITINT_EXTENSION_INT4)
-            add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
-            add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
-        endif() # USE_BITINT_EXTENSION_INT4
+add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
-        set(target 1)
-    endif()
-endforeach()
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
+add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
-        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-            add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
-        endif()
-        if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
-            add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
-        endif()
-        set(target 1)
-    endif()
-endforeach()
+add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) + +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) +endif() # USE_BITINT_EXTENSION_INT4 + +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) diff --git a/example/30_grouped_conv_fwd_multiple_d/README.md b/example/30_grouped_conv_fwd_multiple_d/README.md index 739a0425a8c29419c3e31e433815c5230bd1a5cb..1165634e1afb85049b688efbcf3882901c955a77 100644 --- a/example/30_grouped_conv_fwd_multiple_d/README.md +++ b/example/30_grouped_conv_fwd_multiple_d/README.md @@ -4,7 +4,7 @@ arg1: verification (0=no, 1=yes) arg2: initialization (0=no init, 1=integer value, 2=decimal value) arg3: time kernel (0=no, 1=yes) Following arguments (depending on number of spatial dims): - Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) + Number of spatial dimensions (1=Conv1D, 2=Conv2D, 3=Conv3D) G, N, K, C, , (ie Y, X for 2D) , (ie Hi, Wi for 2D) @@ -16,15 +16,3 @@ Following arguments (depending on number of spatial dims): ./bin/example_grouped_conv_fwd_bias_relu_add_xdl_fp16 1 1 1 ``` -Result (MI100) -``` -in: dim 5, lengths {1, 128, 192, 71, 71}, strides {192, 967872, 1, 13632, 192} -wei: dim 5, lengths {1, 256, 192, 3, 3}, strides {442368, 1728, 1, 576, 192} -bias: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0} -residual: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0} -out: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 331776, 1, 9216, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... 
-Perf: 1.55981 ms, 94.0927 TFlops, 213.868 GB/s, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<256, 128, 256, 16, Default> -``` diff --git a/example/30_grouped_conv_fwd_multiple_d/common.hpp b/example/30_grouped_conv_fwd_multiple_d/common.hpp index e60ebee6e4fe2e08db432816b75bbbd4c388946c..39463614f5b34dde88992fe43e32b1985c975a38 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -12,7 +12,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp index 039d25029921491e7e67808e554b8cb3e6eb4745..ff873d26bcf927b9335fc9a2ded2d199ef3384e9 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" // kernel data types using InKernelDataType = FP16; @@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp index 793324970e700748ac4b18e9fc429a24f70c007d..662a6f611b5f177251a4bafd32d4415fbc144992 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
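Both WMMA convolution examples now gate their `main` on a runtime architecture check. A minimal sketch of how that guard could be shared instead of repeated (the `require_gfx11` helper is hypothetical; `ck::is_gfx11_supported()` and `ck::get_device_name()` are the calls already used in the mains above):

```
#include <iostream>

#include "ck/host_utility/device_prop.hpp"

// Returns true when the current device is a gfx11 (WMMA-capable) GPU; otherwise
// prints the same warning as the example mains and returns false.
inline bool require_gfx11(const char* example_name)
{
    if(!ck::is_gfx11_supported())
    {
        std::cout << "WARNING: " << example_name << " not supported on the platform "
                  << ck::get_device_name() << std::endl;
        return false;
    }
    return true;
}
```

With such a helper, each WMMA `main` reduces to something like `return require_gfx11("wmma example") ? !run_grouped_conv_fwd_bias_relu_add_example(argc, argv) : 0;`.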
#include "common_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" // kernel data types using InKernelDataType = I8; @@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp index 5494563fdd568d1f51ec9ee9042d094096e8fc8a..6f91d51a5fd7195472805a84d1f123b208fe9f67 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +#endif diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index eb242203eaa65ba7b8fca45c5f1e3e1b9d96c409..e3370b880bb29e78a88e5ed4ca1fcdcd14afa2f8 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -34,7 +34,7 @@ using ResidualLayout = typename LayoutSettingSelector::ResidualLayo template using DeviceConvFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InputLayout, WeightLayout, diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index 360b2c8947b371b1502afb3af73c469f0d8c3636..3248c5fa4d3dd6f520bdf19eb1435a43683e6a29 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
template struct LayoutSetting @@ -42,41 +42,42 @@ using DeviceConvFwdInstance = OutputLayout, InKernelDataType, WeiKernelDataType, - ck::Tuple, - OutKernelDataType, AccDataType, CShuffleDataType, + ck::Tuple, + OutKernelDataType, InElementOp, WeiElementOp, OutElementOp, ConvSpec, // ConvForwardSpecialization GemmSpec, // GemmSpecialization - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock 8, // K1 16, // MPerWMMA 16, // NPerWMMA 4, // MRepeat - 2, // NRepeat - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + 1, // NRepeat + S<4, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferDstScalarPerVector_AK1 true, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<4, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferDstScalarPerVector_BK1 true, // BBlockLdsExtraN - 4, - 2, - S<1, 32, 1, 8>, + 1, + 1, + S<1, 16, 1, 8>, 8>; template @@ -277,9 +278,10 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) switch(conv_param.num_dim_spatial_) { - case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); - case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); - case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); + // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); + case 2: + return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); + // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); } return false; diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc index 58ed69182e0622751dc3fa2823fcb7cfa5b279a8..da65bb1886fe11c8fdf888e772fde6892e7b3d29 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc @@ -3,7 +3,7 @@ template using DeviceConvFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InputLayout, WeightLayout, diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 2074520f8c06d527be4ff1ea6a87c388cfbd6391..8b648a7f734418a465ceef1e834eda5341cea615 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,26 +1,10 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx908 gfx90a) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - 
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) - endif() + add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) endif() diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp index d166214c3376cd90afd11796a4cb85ea28421861..2caee6b8dc5dc0a21bcc7b6526b463325cbce3d1 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp @@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o Gemm1 */ -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include #include @@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); #endif int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 
0 : 1; } +#endif diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 0463bf6bd31483ada687f17ff175e0195404623c..519f53910698465fdfb15b1f61628e508c2941d9 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,24 +1,29 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) - add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) - add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) -endif() +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) +add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) +add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) +add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) add_custom_target(example_gemm_scale_softmax_gemm) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) - add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) - add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) -endif() + +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) 
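The attention examples registered here all instantiate the fused Gemm + Softmax + Gemm device op described in the source comments below: C = Softmax(scale * A * B0) * B1, optionally with a lower-triangle mask. A minimal host sketch of that math for a single batch, row-major and without masking (the `attention_reference` function is hypothetical and written only for illustration; the examples themselves verify against ReferenceBatchedGemm and ReferenceSoftmax):

```
#include <algorithm>
#include <cmath>
#include <vector>

// C[M, N] = Softmax_rowwise(scale * A[M, K] * B0[K, L]) * B1[L, N]
std::vector<float> attention_reference(const std::vector<float>& A,
                                       const std::vector<float>& B0,
                                       const std::vector<float>& B1,
                                       int M, int K, int L, int N, float scale)
{
    std::vector<float> S(static_cast<size_t>(M) * L, 0.f);
    std::vector<float> C(static_cast<size_t>(M) * N, 0.f);
    // Gemm0, with the Scale element-wise op applied to the accumulator
    for(int m = 0; m < M; ++m)
        for(int l = 0; l < L; ++l)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * B0[k * L + l];
            S[m * L + l] = scale * acc;
        }
    // Numerically stable row-wise softmax over the L dimension
    for(int m = 0; m < M; ++m)
    {
        float row_max = S[m * L];
        for(int l = 1; l < L; ++l)
            row_max = std::max(row_max, S[m * L + l]);
        float row_sum = 0.f;
        for(int l = 0; l < L; ++l)
        {
            S[m * L + l] = std::exp(S[m * L + l] - row_max);
            row_sum += S[m * L + l];
        }
        for(int l = 0; l < L; ++l)
            S[m * L + l] /= row_sum;
    }
    // Gemm1
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int l = 0; l < L; ++l)
                C[m * N + n] += S[m * L + l] * B1[l * N + n];
    return C;
}
```

The MaskOutUpperTriangle variants change only the softmax input: entries with l > m are treated as -inf before the row-wise softmax.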
+add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69ab5c5c0bbba9a4d7f7e08579fd94b71083c9a7 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. 
Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceMHAFactory = + std::tuple, // ABlockTransfer MK -> K0 M K1 + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // B0BlockTransfer LK -> K0 L K1 + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 8, 8>, // B1BlockTransfer NL -> L0 N L1 + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 8, + 1, + false, + 1, // CShuffleMWmmaPerWavePerShuffle + 2, // CShuffleNWmmaPerWavePerShuffle + S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpec> // MaskingSpecialization + >; +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << 
"WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f5cedb14c9298f964763564f2adeef9645d1642a --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + 
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + 
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// 
Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41c6dff2dfefbff3abc5649f7518f13ef70a9d87 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp @@ -0,0 +1,365 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format 
off +#define CK_MHA_USE_WAVE_1 +#define CK_MHA_USE_WAVE_2 +#define CK_MHA_USE_WAVE_4 +//#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 192, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec> +#endif +#ifdef CK_MHA_USE_WAVE_8 + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 192, 48, 8,4, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 64, 48, 8,4, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_cross_attention_wmma.inc" + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git 
a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..955c25f0d143d9f41dead6d18b1ede00ecca99e8 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp @@ -0,0 +1,313 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Grouped Query Attention, +Ainslie, Joshua, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit +Sanghai. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” +arXiv, May 22, 2023. https://doi.org/10.48550/arXiv.2305.13245. + +Example is GQA-4 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; +static constexpr ck::index_t QueryGroupNumber = 4; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, 
B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, 
NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using 
ReferenceGemm0Instance = + ck::tensor_operation::host::ReferenceBatchedGemm_GQA; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = + ck::tensor_operation::host::ReferenceBatchedGemm_GQA; + +#include "run_grouped_query_attention_forward_wmma.inc" + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..112be07c4963e37c3b6c65ccbc5b437ee33fd569 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Multi-Query Attention +Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.” arXiv.org, November 6, +2019. https://arxiv.org/abs/1911.02150v1. + +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = 
ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // 
B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 
32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm_MQA; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_MQA; + +#include "run_multi_query_attention_forward_wmma.inc" + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc new file mode 100644 index 0000000000000000000000000000000000000000..2e77479bccad137b00728b8280d8fbc6ea0f37f0 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 120; + ck::index_t N = 1000; + ck::index_t K = 64; + ck::index_t O = 128; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 7; + ck::index_t G1 = 13; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? 
std::vector{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::vector{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::vector{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + 
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + ck::index_t BatchCount = G0 * G1; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, M, K}); + Tensor b0_g_k_n({BatchCount, K, N}); + Tensor b1_g_n_o({BatchCount, N, O}); + Tensor acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking 
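+ // Host-side note (sketch, mirroring what the device kernel is expected to do): the
+ // C0MatrixMask object below is queried per (row, col) index of the Gemm0 output, and any
+ // element it reports as masked is overwritten with -inf so it collapses to a zero weight
+ // in the reference softmax step that follows. With MaskingSpecialization::MaskDisabled
+ // this loop is effectively a no-op; with a causal specialization such as
+ // MaskOutUpperTriangle it blanks out key positions past the current query position,
+ // roughly:
+ //     if(col > row) { acc0(row, col) = -inf; } // illustrative only, not kernel code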
+ const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * G1 + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc new file mode 100644 index 0000000000000000000000000000000000000000..9ff4c56e0695aaf02fc88661722d7f410a928784 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc @@ -0,0 +1,384 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t q_sequence_length = 256; + ck::index_t kv_sequence_length = 64; + ck::index_t head_dim = 80; + + // Output shape C[batch_size, q_sequence_length, head_num, head_dim]. 
Batch dim, outer dim, + // inner dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o = + // permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t batch_size = 2; + ck::index_t head_num = 8; + + float alpha = 1; + bool input_permute = true; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + q_sequence_length = std::stoi(argv[4]); + kv_sequence_length = std::stoi(argv[5]); + head_dim = std::stoi(argv[6]); + batch_size = std::stoi(argv[7]); + head_num = std::stoi(argv[8]); + + alpha = std::stof(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 8: q_sequence_length, kv_sequence_length, head_dim, batch_size, head_num\n"); + printf("arg9: scale (alpha)\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{batch_size, head_num, q_sequence_length, head_dim}; + std::vector a_gs_ms_ks_strides = + input_permute ? std::vector{q_sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // A layout [batch_size, q_sequence_length, head_num, head_dim] + : std::vector{ + head_num * q_sequence_length * head_dim, + q_sequence_length * head_dim, + head_dim, + 1}; // A layout [batch_size, head_num, q_sequence_length, head_dim] + + std::vector b0_gs_ns_ks_lengths{ + batch_size, head_num, kv_sequence_length, head_dim}; + std::vector b0_gs_ns_ks_strides = + input_permute ? std::vector{kv_sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // B0 layout [batch_size, kv_sequence_length, head_num, head_dim] + : std::vector{ + head_num * kv_sequence_length * head_dim, + kv_sequence_length * head_dim, + head_dim, + 1}; // B0 layout [batch_size, head_num, kv_sequence_length, head_dim] + + std::vector b1_gs_os_ns_lengths{ + batch_size, head_num, head_dim, kv_sequence_length}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{kv_sequence_length * head_num * head_dim, + head_dim, + 1, + head_num * head_dim} + // B1 layout [batch_size, kv_sequence_length, head_num, head_dim] + : std::vector{ + head_num * kv_sequence_length * head_dim, + kv_sequence_length * head_dim, + 1, + head_dim}; // B1 layout [batch_size, head_num, kv_sequence_length, head_dim] + + std::vector c_gs_ms_os_lengths{batch_size, head_num, q_sequence_length, head_dim}; + std::vector c_gs_ms_os_strides = + output_permute ? 
std::vector{q_sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // C layout [batch_size, q_sequence_length, head_num, head_dim] + : std::vector{ + head_num * q_sequence_length * head_dim, + q_sequence_length * head_dim, + head_dim, + 1}; // C layout [batch_size, head_num, q_sequence_length, head_dim] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + std::vector kv_gs_ns_ks_lengths{ + batch_size, head_num, kv_sequence_length, 2, head_dim}; + std::vector kv_gs_ns_ks_strides = std::vector{ + kv_sequence_length * head_num * 2 * head_dim, + 2 * head_dim, + head_num * 2 * 
head_dim, + head_dim, + 1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim] + Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides); + // merge kv into a packed pointer send to device + b0_gs_ns_ks.ForEach( + [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); + b1_gs_os_ns.ForEach( + [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[3], 1, idx[2]) = self(idx); }); + DeviceMem q_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem kv_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() + + sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + q_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + kv_device_buf.ToDevice(kv_gs_ns_ks.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeCrossAttnInvoker(); + auto argument = + gemm.MakeCrossAttnArgument(static_cast(q_device_buf.GetDeviceBuffer()), + static_cast(kv_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + batch_size, + q_sequence_length, + kv_sequence_length, + head_num, + head_dim, + alpha); + + // if(!gemm.IsSupportedArgument(argument)) + // { + // std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + // } + + ck::index_t BatchCount = batch_size * head_num; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(q_sequence_length) * kv_sequence_length * head_dim * 2 + + size_t(q_sequence_length) * kv_sequence_length * head_dim * 2) * + BatchCount; + std::size_t num_btype = (sizeof(ADataType) * q_sequence_length * head_dim + + sizeof(B0DataType) * head_dim * kv_sequence_length + + sizeof(B1DataType) * kv_sequence_length * head_dim + + sizeof(CDataType) * q_sequence_length * head_dim) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, q_sequence_length, head_dim}); + Tensor b0_g_k_n({BatchCount, head_dim, kv_sequence_length}); + Tensor b1_g_n_o({BatchCount, kv_sequence_length, head_dim}); + Tensor acc0_g_m_n( + {BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, + q_sequence_length, + kv_sequence_length}); // scratch object after softmax + Tensor c_g_m_o_host_result( + {BatchCount, q_sequence_length, head_dim}); // scratch object after gemm1 + + // 
permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(kv_sequence_length); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * head_num + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? 
"YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num + << ", q_sequence_length: " << q_sequence_length + << ", kv_sequence_length: " << kv_sequence_length << ", head_dim: " << head_dim + << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc new file mode 100644 index 0000000000000000000000000000000000000000..609d085299e62b5175c7af471b1ce87d75d4217d --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 64; + ck::index_t O = 64; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 4; + ck::index_t G1 = 16; + ck::index_t KV_head = QueryGroupNumber; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, KV_head, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? 
std::vector{N * KV_head * K, K, KV_head * K, 1} + // B0 layout [G0, N, G1, K] + : std::vector{KV_head * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, KV_head, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * KV_head * O, O, 1, KV_head * O} + // B1 layout [G0, N, G1, O] + : std::vector{KV_head * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + 
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; + std::size_t num_btype = + (sizeof(ADataType) * M * K + sizeof(CDataType) * M * O) * G0 * G1 + + (sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O) * G0 * QueryGroupNumber; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g0_g1_m_k({G0, G1, M, K}); + Tensor b0_g0_gq_k_n({G0, QueryGroupNumber, K, N}); + Tensor b1_g0_gq_n_o({G0, QueryGroupNumber, N, O}); + Tensor acc0_g0_g1_m_n({G0, G1, M, N}); // scratch object after gemm0 + Tensor a1_g0_g1_m_n({G0, G1, M, N}); // scratch object after softmax + Tensor c_g0_g1_m_o_host_result({G0, G1, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g0_g1_m_k(idx[0], idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g0_gq_k_n(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g0_gq_n_o(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument(a_g0_g1_m_k, + b0_g0_gq_k_n, + acc0_g0_g1_m_n, + a_element_op, + 
b0_element_op, + acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g0_g1_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[2], idx[3])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = + ref_softmax.MakeArgument(acc0_g0_g1_m_n, a1_g0_g1_m_n, 1, 0, {3}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g0_g1_m_n, + b1_g0_gq_n_o, + c_g0_g1_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = c_g0_g1_m_o_host_result(idx); }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MQA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc new file mode 100644 index 0000000000000000000000000000000000000000..b05915c07fb681ef2e1693da9ab68e4b925e2731 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc @@ -0,0 +1,339 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 120; + ck::index_t N = 1000; + ck::index_t K = 64; + ck::index_t O = 128; + + // Output shape C[G0, M, G1, O]. 
Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 7; + ck::index_t G1 = 13; + ck::index_t KV_head = 1; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, KV_head, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * KV_head * K, K, KV_head * K, 1} + // B0 layout [G0, N, G1, K] + : std::vector{KV_head * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, KV_head, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * KV_head * O, O, 1, KV_head * O} + // B1 layout [G0, N, G1, O] + : std::vector{KV_head * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? 
std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + 
b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(CDataType) * M * O) * G0 * G1 + + (sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O) * G0; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g0_g1_m_k({G0, G1, M, K}); + Tensor b0_g0_1_k_n({G0, 1, K, N}); + Tensor b1_g0_1_n_o({G0, 1, N, O}); + Tensor acc0_g0_g1_m_n({G0, G1, M, N}); // scratch object after gemm0 + Tensor a1_g0_g1_m_n({G0, G1, M, N}); // scratch object after softmax + Tensor c_g0_g1_m_o_host_result({G0, G1, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g0_g1_m_k(idx[0], idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g0_1_k_n(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g0_1_n_o(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument(a_g0_g1_m_k, + b0_g0_1_k_n, + acc0_g0_g1_m_n, + a_element_op, + b0_element_op, + acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g0_g1_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[2], idx[3])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = + ref_softmax.MakeArgument(acc0_g0_g1_m_n, a1_g0_g1_m_n, 1, 0, {3}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = 
ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g0_g1_m_n, + b1_g0_1_n_o, + c_g0_g1_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = c_g0_g1_m_o_host_result(idx); }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MQA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc new file mode 100644 index 0000000000000000000000000000000000000000..3fdaaebb0f5ff90c8399c65473773596da586ffb --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t sequence_length = 256; + ck::index_t head_dim = 80; + + // Output shape C[batch_size, sequence_length, head_num, head_dim]. 
Batch dim, outer dim, inner + // dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o = + // permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t batch_size = 2; + ck::index_t head_num = 8; + + float alpha = 1; + bool input_permute = true; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + sequence_length = std::stoi(argv[4]); + head_dim = std::stoi(argv[5]); + batch_size = std::stoi(argv[6]); + head_num = std::stoi(argv[7]); + + alpha = std::stof(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: sequence_length, head_dim, batch_size, head_num\n"); + printf("arg8: scale (alpha)\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{batch_size, head_num, sequence_length, head_dim}; + std::vector a_gs_ms_ks_strides = + input_permute ? std::vector{sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // A layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + head_dim, + 1}; // A layout [batch_size, head_num, sequence_length, head_dim] + + std::vector b0_gs_ns_ks_lengths{batch_size, head_num, sequence_length, head_dim}; + std::vector b0_gs_ns_ks_strides = + input_permute ? std::vector{sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // B0 layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + head_dim, + 1}; // B0 layout [batch_size, head_num, sequence_length, head_dim] + + std::vector b1_gs_os_ns_lengths{batch_size, head_num, head_dim, sequence_length}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{sequence_length * head_num * head_dim, + head_dim, + 1, + head_num * head_dim} + // B1 layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + 1, + head_dim}; // B1 layout [batch_size, head_num, sequence_length, head_dim] + + std::vector c_gs_ms_os_lengths{batch_size, head_num, sequence_length, head_dim}; + std::vector c_gs_ms_os_strides = + output_permute ? 
std::vector{sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // C layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + head_dim, + 1}; // C layout [batch_size, head_num, sequence_length, head_dim] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + std::vector qkv_gs_ms_ks_lengths{ + batch_size, head_num, sequence_length, 3, head_dim}; + std::vector qkv_gs_ms_ks_strides = std::vector{ + sequence_length * head_num * 3 * head_dim, + 3 * head_dim, + head_num * 3 * head_dim, + 
head_dim, + 1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim] + Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides); + // merge qkv into a packed pointer send to device + a_gs_ms_ks.ForEach( + [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); + b0_gs_ns_ks.ForEach( + [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 1, idx[3]) = self(idx); }); + b1_gs_os_ns.ForEach( + [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[3], 2, idx[2]) = self(idx); }); + DeviceMem qkv_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize() + + sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() + + sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + qkv_device_buf.ToDevice(qkv_gs_ms_ks.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeSelfAttnInvoker(); + auto argument = + gemm.MakeSelfAttnArgument(static_cast(qkv_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + batch_size, + sequence_length, + head_num, + head_dim, + alpha); + + // if(!gemm.IsSupportedArgument(argument)) + // { + // std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + // } + + ck::index_t BatchCount = batch_size * head_num; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(sequence_length) * sequence_length * head_dim * 2 + + size_t(sequence_length) * sequence_length * head_dim * 2) * + BatchCount; + std::size_t num_btype = (sizeof(ADataType) * sequence_length * head_dim + + sizeof(B0DataType) * head_dim * sequence_length + + sizeof(B1DataType) * sequence_length * head_dim + + sizeof(CDataType) * sequence_length * head_dim) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, sequence_length, head_dim}); + Tensor b0_g_k_n({BatchCount, head_dim, sequence_length}); + Tensor b1_g_n_o({BatchCount, sequence_length, head_dim}); + Tensor acc0_g_m_n( + {BatchCount, sequence_length, sequence_length}); // scratch object after gemm0 + Tensor a1_g_m_n( + {BatchCount, sequence_length, sequence_length}); // scratch object after softmax + Tensor c_g_m_o_host_result( + {BatchCount, sequence_length, head_dim}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] 
* head_num + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(sequence_length); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * head_num + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num + << ", sequence_length: " << sequence_length << ", head_dim: " << head_dim + << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ec1bc933f75d6af399f546a84ef221b910e4e22 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. 
All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +#define CK_MHA_USE_WAVE_1 +#define CK_MHA_USE_WAVE_2 +#define CK_MHA_USE_WAVE_4 +//#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer 
MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + 
MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 192, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec> +#endif +#ifdef 
CK_MHA_USE_WAVE_8 + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 192, 48, 8,4, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_self_attention_wmma.inc" + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/34_batchnorm/batchnorm_infer_impl.hpp b/example/34_batchnorm/batchnorm_infer_impl.hpp index d0b545b2a31d2fc095eece12ce273db823a5d2ac..ac6a0e451d765d6a75c87e5403e971729abc4dd0 100644 --- a/example/34_batchnorm/batchnorm_infer_impl.hpp +++ b/example/34_batchnorm/batchnorm_infer_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
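// bnorm_infer builds its normalization kernel from the dynamic-vector-dims elementwise device
// op, so the instantiation further down spells out the 2D tile explicitly: a 64-thread block
// covering a 32 x 32 tile, 4 x 4 elements per thread, and ck::Sequence<1, 0> as the
// thread-cluster arrange order.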
#pragma once @@ -10,7 +10,7 @@ #include "ck/utility/sequence.hpp" #include "ck/utility/tuple.hpp" #include "ck/utility/reduction_operator.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "batchnorm_common.hpp" @@ -54,7 +54,12 @@ int bnorm_infer( ck::Tuple, // y NormalizeInInfer, Rank, - 2, // MPerthread + 64, // BlockSize + 32, // MPerBlock + 32, // NPerBlock + 4, // MPerthread + 4, // NPerthread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder ck::Sequence<1, 1, 1, 1, 1>, // x, mean, variance, scale, bias ck::Sequence<1>>; // scalarPerVector: y diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 251a9b93c57e3b24ae9a1305bbe7607eb696817f..904006ba3600b5ca5a3dca0f6a4e1b707688ebdb 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,28 +1,29 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_splitK_gemm_xdl) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bfp16) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) - endif() - set(target 1) - endif() -endforeach() +add_custom_target(example_splitK_gemm_xdl) +add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) + +add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) + +add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) + +add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) + +add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) + +add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) +endif() + 
+add_example_executable(example_gemm_xdl_splitk_reduce_multi_d_fp16 gemm_xdl_splitk_reduce_multi_d_fp16.cpp) +add_example_executable(example_gemm_xdl_splitk_reduce_multi_d_bf16 gemm_xdl_splitk_reduce_multi_d_bf16.cpp) +add_example_executable(example_gemm_xdl_splitk_reduce_bf16A_i8B gemm_xdl_splitk_reduce_bf16A_i8B.cpp) + +add_example_executable(example_gemm_xdl_splitk_reduce_bfp16 gemm_xdl_splitk_reduce_bf16.cpp) diff --git a/example/35_splitK_gemm/common.hpp b/example/35_splitK_gemm/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..64fadae9e5c5e20bff445e5826fcae9ea6e0830c --- /dev/null +++ b/example/35_splitK_gemm/common.hpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +struct ProblemSizeSplitK final +{ + ck::index_t M = 256; + ck::index_t N = 1024; + ck::index_t K = 512; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideC = N; + + ck::index_t KBatch = 2; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 2; + bool time_kernel = true; +}; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeSplitK& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.KBatch = std::stoi(argv[10]); + } + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg10: KBatch" << std::endl; + return false; + } + + return true; +} diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..7ceb1d09ef9c7c51129694444600d18ccf487fdc --- /dev/null +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = ck::bhalf_t; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple<>; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple<>; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 64, + 8, 4, + 32, 32, + 2, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 4, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b5aeff65d66f561305b03bd7da7f7ce217e7f076 --- /dev/null +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
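// The gemm_xdl_splitk_reduce_* examples in this directory split the K dimension into KBatch
// chunks; each chunk produces a partial product that is reduced (summed) into the final C.
// A minimal host-side sketch of that idea - plain C++ with a made-up helper name, not the CK
// device API - looks like this:

#include <algorithm>
#include <cstddef>
#include <vector>

// Naive split-K reference: C[M x N] = sum over KBatch slices of A_slice * B_slice.
inline std::vector<float> split_k_gemm_sketch(const std::vector<float>& a, // M x K, row-major
                                              const std::vector<float>& b, // K x N, row-major
                                              std::size_t M,
                                              std::size_t N,
                                              std::size_t K,
                                              std::size_t KBatch)
{
    std::vector<float> c(M * N, 0.f);
    const std::size_t k_per_batch = (K + KBatch - 1) / KBatch;
    for(std::size_t kb = 0; kb < KBatch; ++kb)
    {
        // Each kb iteration is an independent partial GEMM over one K slice; on the GPU the
        // slices run concurrently and their partial results are reduced afterwards.
        const std::size_t k_begin = kb * k_per_batch;
        const std::size_t k_end   = std::min(K, k_begin + k_per_batch);
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float partial = 0.f;
                for(std::size_t k = k_begin; k < k_end; ++k)
                    partial += a[m * K + k] * b[k * N + n];
                c[m * N + n] += partial; // reduction across K batches
            }
    }
    return c;
}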
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = int8_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = float; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple<>; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple<>; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 64, + 8, 4, + 32, 32, + 2, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 4, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cb84f2a416693cb7cb4f4bc638ba49f583ebed5e --- /dev/null +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
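// This "multi_d" variant fuses one extra input tensor D0 (same M x N extent and layout as C)
// into the epilogue: CDEElementOp = Add, so after the split-K partial results are reduced the
// kernel effectively computes, element-wise, C(m, n) = sum_k A(m, k) * B(k, n) + D0(m, n).
// That is also why the host verification in run_gemm_splitk_reduce_multi_d_example.inc applies
// c_element_op(c, c, d0) to the reference GEMM output before comparing with the device result.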
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = float; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 64, + 8, 4, + 32, 32, + 2, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 4, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ab8f77dc4a534f207216c44bdb15d0523331274 --- /dev/null +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; +using ReduceDataType = float; +using D0DataType = ck::half_t; +using DsDataType = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 64, + 8, 4, + 32, 32, + 2, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 4, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v2, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc b/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..9635993d63b5355c1405e03f90ef9762edcacea0 --- /dev/null +++ b/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
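// The get_rtol / get_atol helpers below choose comparison tolerances per output data type:
// full-precision types keep tight bounds, while half/bf16/f8/bf8 results are allowed a wider
// margin because neighbouring representable values in those formats (e.g. the 240 vs 224 and
// 57344 vs 49152 pairs mentioned in the comments) can differ by that much after rounding.
// A typical invocation of these examples, assuming the default build layout, would be
// something like
//   ./example_gemm_xdl_splitk_reduce_multi_d_fp16 1 2 1 256 1024 512 512 1024 1024 4
// i.e. verification on, init method 2, timing on, then M N K StrideA StrideB StrideC, KBatch.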
+ +#pragma once + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto StrideD0 = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + 
a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } +#if 0 + printf("B matrix:\n"); + for (int in = 0; in < N; in++) + { + for (int ik = 0; ik < K; ik++) + { + printf("%02x ", *(reinterpret_cast(&b_k_n(ik,in)))); + if(ik%8==7) printf("|"); + } + printf("\n"); + } +#endif + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << KBatch << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + auto get_argment = [&]() { + if constexpr(DsDataType::Size() > 0) + { + return gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + {d0_m_n_device_buf.GetDeviceBuffer()}, + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + {StrideD0}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + else + { + return gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + {}, + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + }; + auto argument = get_argment(); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer(), StreamConfig{}); + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if constexpr(DsDataType::Size() > 0) + { + c_m_n_host_result.ForEach( + [&](auto& self, auto idx) { c_element_op(self(idx), self(idx), d0_m_n(idx)); }); + } + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz 
* M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/example/35_splitK_gemm/run_splitK_gemm_example.inc b/example/35_splitK_gemm/run_splitK_gemm_example.inc index e9bd5c552dea937deb78859b94b731eec6ec03b0..e3690984abc08cff2e5a450f9f6b6ce13217b883 100644 --- a/example/35_splitK_gemm/run_splitK_gemm_example.inc +++ b/example/35_splitK_gemm/run_splitK_gemm_example.inc @@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con if(config.time_kernel) { - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 1}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp similarity index 100% rename from example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp rename to example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp index 74fb16e15b019e145cdf13549254be1b09fbd8ca..dc54bc30efeb9c5233e2af1d55381fc102d8e3f5 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -42,7 +42,7 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::KPadding; using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle // clang-format off diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b93639e6c11dc430e249645e39b944d4637e99c3 --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
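// The runners included by these split-K examples report performance as 2*M*N*K flops and one
// read of A and B plus one write of C per GEMM. The small stand-alone helper below (a sketch
// with a made-up name, not part of CK) mirrors that arithmetic; the kernel time passed in is
// a placeholder, not a measured value.

#include <cstddef>
#include <iostream>

inline void report_gemm_perf(std::size_t M, std::size_t N, std::size_t K,
                             std::size_t a_bytes, std::size_t b_bytes, std::size_t c_bytes,
                             double ave_time_ms)
{
    const double flop      = 2.0 * static_cast<double>(M) * N * K;
    const double num_bytes =
        static_cast<double>(a_bytes * M * K + b_bytes * K * N + c_bytes * M * N);

    const double tflops     = flop / 1.e9 / ave_time_ms;      // GFLOP per ms == TFLOP/s
    const double gb_per_sec = num_bytes / 1.e6 / ave_time_ms; // MB per ms == GB/s

    std::cout << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl;
}

// Example: the default fp16 x fp8 problem (M=256, N=1024, K=512; 2-byte A/C, 1-byte B) with an
// assumed 0.05 ms kernel time:
//   report_gemm_perf(256, 1024, 512, 2, 1, 2, 0.05);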
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 16, 64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Default, ADataType, BDataType>; + +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..97a3f89e5ef3f90f5f431a20726a19d3de1d11d6 --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
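// The DIRECT_LOAD switch below selects between two device ops: when it is 1, the example uses
// DeviceGemmXdlSplitKCShuffle_LdsDirectLoad, which moves the A/B tiles from global memory
// directly into LDS (so its parameter list has no per-thread destination scalar-per-vector
// knobs for the block transfers); when it is 0, the regular DeviceGemmXdlSplitKCShuffle path
// is kept under #else for comparison.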
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#define DIRECT_LOAD 1 + +#if DIRECT_LOAD +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#endif + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +#if DIRECT_LOAD + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle_LdsDirectLoad + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>; +// clang-format on + +#else + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | 
| | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#endif + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt index 36bb5720d53c1a738323f50d0e855d50d8996047..a9be3a7108fee6075dd7ae81f7536b072e655c0b 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt +++ b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index 36dcf58d7044b3bf54d3e318e5d095a40d70255d..ff1282f3c78d9e227cc8a2132b1cd81fa7c27187 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
/* Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o] @@ -60,14 +60,14 @@ struct AddAddRelu { const ck::half_t x = c + d0 + d1; - ck::tensor_operation::element_wise::Relu{}.template operator()(e, x); + ck::tensor_operation::element_wise::Relu{}.operator()(e, x); } __host__ __device__ void operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const { const float x = c + (d0 + d1); - ck::tensor_operation::element_wise::Relu{}.template operator()(e, x); + ck::tensor_operation::element_wise::Relu{}.operator()(e, x); } }; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 3821f8aaca74a3b028fc3917c10e5f279375c3fe..ce951f6353ceaf4ca4b2ae0ceac8072ac882a1c9 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,15 +1,13 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) - add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp) - add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) - - add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) - add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) - set(target 1) - endif() -endforeach() -endif() +add_custom_target(example_grouped_conv_bwd_data) + +add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) + +add_example_executable(example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8) + +add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) + +add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index ca824b1075ced161760892fe6de53767b1079ccc..8a0474156ccb93bc01323d52d87fc3099e9ab332 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -10,7 +10,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -35,6 +34,8 @@ static constexpr auto ConvBwdDataDefault = using FP16 = ck::half_t; using FP32 = float; +using FP8 = ck::f8_t; +using BF8 = ck::bf8_t; struct ExecutionConfig final { diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp 
b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp deleted file mode 100644 index a3533bb4cc4b4874f0e7ae9f333347913d981ba0..0000000000000000000000000000000000000000 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp +++ /dev/null @@ -1,33 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -using OutDataType = FP16; -using WeiDataType = FP16; -using AccDataType = FP32; -using CShuffleDataType = FP16; -using BiasDataType = FP16; // bias -using InDataType = FP16; - -using OutLayout = ck::tensor_layout::convolution::GNHWK; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using BiasLayout = ck::Tuple; -using InLayout = ck::tensor_layout::convolution::GNHWC; - -using OutElementOp = PassThrough; -using WeiElementOp = PassThrough; -using InElementOp = ck::tensor_operation::element_wise::AddRelu; - -// clang-format off -using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 -// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| -// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; -// clang-format on - -#include "run_grouped_conv_bwd_data_bias_relu_example.inc" - -int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_bias_relu_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c28e25e019d41d2a3cb55e410286ab616dc4722 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
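The bias+ReLU variant added below fuses an AddRelu epilogue into the backward-data kernel, so no separate bias or activation pass is needed. A minimal host-side sketch of what that epilogue computes per output element (an assumption about the fused op's semantics, not CK's definition):

#include <algorithm>

struct AddReluRef
{
    // in_grad = max(conv_bwd_data_result + bias, 0)
    float operator()(float conv_result, float bias) const
    {
        return std::max(conv_result + bias, 0.f);
    }
};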
+ +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using BiasDataType = FP16; // bias +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::Tuple; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = ck::tensor_operation::element_wise::AddRelu; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_grouped_conv_bwd_data_bias_relu_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_bias_relu_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp deleted file mode 100644 index fb688b6f3f15383c2920a12e8902c2d4cf05a109..0000000000000000000000000000000000000000 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp +++ /dev/null @@ -1,33 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "common.hpp" - -using OutDataType = FP16; -using WeiDataType = FP16; -using AccDataType = FP32; -using CShuffleDataType = FP16; -using DsDataType = ck::Tuple<>; -using InDataType = FP16; - -using OutLayout = ck::tensor_layout::convolution::GNHWK; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using DsLayout = ck::Tuple<>; -using InLayout = ck::tensor_layout::convolution::GNHWC; - -using OutElementOp = PassThrough; -using WeiElementOp = PassThrough; -using InElementOp = PassThrough; - -// clang-format off -using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 -// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| -// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; -// clang-format on - -#include "run_grouped_conv_bwd_data_example.inc" - -int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3e3ae7edbde107fce0e7c1eeea3667472c7e7ce6 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" +#include "common.hpp" +#include "ck/host_utility/device_prop.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + //| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < 2,OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run_grouped_conv_bwd_data_example(argc, argv); +} diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b1554412b1adca12cd49231a1c59cd6f34462ebf --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41023ef82ae94266db18da8bcd2e54910315918c --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; +using AComputeType = BF8; +using BComputeType = FP8; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| Loop| ACompute| BCompute| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, AComputeType, BComputeType>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/39_permute/CMakeLists.txt b/example/39_permute/CMakeLists.txt index 5b43de9725aa3d535851926884fc6790233c6dbb..8b850c89a935ed194cad918a2bdbcf4b0c77d739 100644 --- a/example/39_permute/CMakeLists.txt +++ b/example/39_permute/CMakeLists.txt @@ -1,11 +1,10 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_custom_target(example_permute) +add_custom_target(example_permute) - add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) - add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) - add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) +add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) 
+add_example_dependencies(example_permute example_permute_1xHxW_fp16) - add_dependencies(example_permute example_permute_1xHxW_fp16) - add_dependencies(example_permute example_permute_NxHxW_fp16) - add_dependencies(example_permute example_permute_HxWx4_fp16) -endif() +add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) +add_example_dependencies(example_permute example_permute_NxHxW_fp16) + +add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) +add_example_dependencies(example_permute example_permute_HxWx4_fp16) diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 55464957ac62e831e332e6b88b97e283e22866c8..991c1e464be931cfcfc20df599a26233d8e6abc1 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,28 +1,17 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - if(DL_KERNELS) - # Conv perlayer quantization - add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) - # Conv perchannel quantization - add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) - # Conv + bias + relu perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) - # Conv + bias + relu perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) - # Conv + bias + tanh perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) - # Conv + bias + tanh perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) - endif() -endif() \ No newline at end of file +# Conv perlayer quantization +add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) +# Conv perchannel quantization 
+add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) +# Conv + bias + relu perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) +# Conv + bias + relu perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) +# Conv + bias + tanh perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) +# Conv + bias + tanh perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp index 06c839e4e226e082e9754f12ce39cececbd36269..8c0049b0fa84dd836db25d2751c3f0d2aeca07a3 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" using InDataType = int8_t; using WeiDataType = int8_t; @@ -33,7 +33,7 @@ template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp index 7a9b42d39f971ab1c36476f4c0ead172cb6d2606..e18c123f7c3a18a5d6017bc5cf045c2060da18e8 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
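The per-layer and per-channel int8 examples differ only in how the int32 convolution accumulator is scaled back to int8: one scale for the whole tensor versus one scale per output channel. A generic sketch of that requantization step (an assumption about the scheme in general, not CK's exact epilogue operator):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

inline std::int8_t requantize(std::int32_t acc, float scale)
{
    // scale, round, then saturate to the int8 range
    const float v = std::nearbyint(static_cast<float>(acc) * scale);
    return static_cast<std::int8_t>(std::min(127.f, std::max(-128.f, v)));
}

// per-layer: a single scale for the whole output tensor
std::int8_t requantize_per_layer(std::int32_t acc, float layer_scale)
{
    return requantize(acc, layer_scale);
}

// per-channel: one scale per output channel k
std::int8_t requantize_per_channel(std::int32_t acc, const std::vector<float>& channel_scale, int k)
{
    return requantize(acc, channel_scale[k]);
}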
#include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" using InDataType = int8_t; using WeiDataType = int8_t; @@ -31,7 +31,7 @@ template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp index 3495636297d5d83fb3ccc7d236e6135de9b715fd..53f810cc9ec078c7ee4da9ee3961fc7cac9f84a1 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" using InDataType = int8_t; using WeiDataType = int8_t; @@ -31,7 +31,7 @@ template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp index 2611337254212b326f6eef540c6a4eca5cd19261..9db6e201ddd14704fd7b4903d4510a2c91fafee4 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" using InDataType = int8_t; using WeiDataType = int8_t; @@ -26,7 +26,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index 085d90a9e9c5cc773c9935ba954853444e35740f..8ab56b21a638c518d6d2584a6ffb3b94a6130470 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -1,26 +1,10 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx908 gfx90a) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) - endif() + add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) endif() diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp index 80f6e9ae05712b1df4b05b1bb75fbff62d9ca70a..cf7b1ce3a813303841c47ae8b3084348563b41bb 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include #include @@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); #endif int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 
0 : 1; } +#endif diff --git a/example/42_groupnorm/CMakeLists.txt b/example/42_groupnorm/CMakeLists.txt deleted file mode 100644 index bc2246a2bfd39ad933f42b5b5f47d135f6278f76..0000000000000000000000000000000000000000 --- a/example/42_groupnorm/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) - add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) - add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) -endif() diff --git a/example/42_groupnorm/common.hpp b/example/42_groupnorm/common.hpp deleted file mode 100644 index c8f91eb53b5b943bfbdd264b842cf7762a176b82..0000000000000000000000000000000000000000 --- a/example/42_groupnorm/common.hpp +++ /dev/null @@ -1,24 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/utility/reduction_enums.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" -#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" - -#include "ck/library/utility/fill.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_common_util.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp" diff --git a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp deleted file mode 100644 index cc107b63dcd87735ba40ebbe7633c61993406166..0000000000000000000000000000000000000000 --- a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; - -struct YElementOp -{ - template - __host__ __device__ void operator()(T& y, const T& x) const - { - static_assert(ck::is_same::value || ck::is_same::value || - ck::is_same::value, - "Data type is not supported by this operation!"); - - T a; - - ck::tensor_operation::element_wise::Sigmoid{}(a, x); - - y = x * a; - }; -}; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector - -#include "run_groupnorm_example.inc" - -int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); } diff --git a/example/42_groupnorm/groupnorm_splitk_fp16.cpp b/example/42_groupnorm/groupnorm_splitk_fp16.cpp deleted file mode 100644 index 057b240a63fc38217c095b6f3a058e1d4db6ea27..0000000000000000000000000000000000000000 --- a/example/42_groupnorm/groupnorm_splitk_fp16.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "common.hpp" - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using YElementOp = ck::tensor_operation::element_wise::Swish; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationSplitKImpl; // OutScalarPerVector - -#include "run_groupnorm_example.inc" - -int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); } diff --git a/example/42_groupnorm/groupnorm_swish_fp16.cpp b/example/42_groupnorm/groupnorm_swish_fp16.cpp deleted file mode 100644 index 363f22ed4c0015c9a96aa98a8b6b5302978e3d25..0000000000000000000000000000000000000000 --- a/example/42_groupnorm/groupnorm_swish_fp16.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using YElementOp = ck::tensor_operation::element_wise::Swish; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector - -#include "run_groupnorm_example.inc" - -int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); } diff --git a/example/42_groupnorm/run_groupnorm_example.inc b/example/42_groupnorm/run_groupnorm_example.inc deleted file mode 100644 index 16065c8d46a42433416e21e9607dc3fcb708d817..0000000000000000000000000000000000000000 --- a/example/42_groupnorm/run_groupnorm_example.inc +++ /dev/null @@ -1,113 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -int run_groupnorm_example(int argc, char* argv[]) -{ - ck::index_t N = 32; - ck::index_t H = 16; - ck::index_t W = 16; - ck::index_t G = 64; - ck::index_t C = 128; - - if(argc == 1) - { - // use default case - } - else if(argc == 6) - { - N = std::stoi(argv[1]); - H = std::stoi(argv[2]); - W = std::stoi(argv[3]); - G = std::stoi(argv[4]); - C = std::stoi(argv[5]); - } - else - { - std::cerr << "arg1 to 5: N, H, W, G, C" << std::endl; - - return 1; - } - - Tensor x({N, H, W, G, C}); - Tensor y({N, H, W, G, C}); - Tensor gamma({G, C}); - Tensor beta({G, C}); - - ck::utils::FillUniformDistribution{0.f, 1.f}(x); - ck::utils::FillUniformDistribution{0.f, 1.f}(gamma); - ck::utils::FillUniformDistribution{0.f, 1.f}(beta); - - DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); - DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); - DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); - DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); - - x_dev.ToDevice(x.mData.data()); - gamma_dev.ToDevice(gamma.mData.data()); - beta_dev.ToDevice(beta.mData.data()); - - const auto y_element_op = YElementOp{}; - - auto device_instance = DeviceInstance{}; - auto argument_ptr = device_instance.MakeArgumentPointer( - {N, H, W, G, C}, - std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, - {0, 0, 0, C, 1}, - {0, 0, 0, C, 1}, - std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, - {1, 2, 4}, // reduction dimension: [H, W, C] - 1e-6, - x_dev.GetDeviceBuffer(), - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - y_dev.GetDeviceBuffer(), - nullptr, - nullptr, - y_element_op); - - if(!device_instance.IsSupportedArgument(argument_ptr.get())) - { - std::cout << "The runtime parameters are not supported" << std::endl; - return 1; - }; - - size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); - DeviceMem workspace_dev(workspace_sz); - device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); - - auto invoker_ptr = device_instance.MakeInvokerPointer(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true}); - - std::size_t num_btype = sizeof(XDataType) * N * H * W * G * C + - sizeof(YDataType) * N * H * W * G * C + sizeof(GammaDataType) * G * C + - sizeof(BetaDataType) * G * C; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s, " - << device_instance.GetTypeString() << std::endl; - - bool pass = true; - { - Tensor host_y({N, H, W, G, C}); - using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm; - - ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, y_element_op, {N, H, W, G, C}, 1e-6); - auto ref_invoker = ref.MakeInvoker(); - ref_invoker.Run(ref_argument); - - y_dev.FromDevice(y.mData.data()); - pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); - } - - return (pass ? 
0 : 1); -} diff --git a/example/42_groupnorm_fwd/CMakeLists.txt b/example/42_groupnorm_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7d08baccd0d6938797f54aa83c4afabebdb26a3a --- /dev/null +++ b/example/42_groupnorm_fwd/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_groupnorm_fwd_sigmoid_mul_fp16 groupnorm_fwd_sigmoid_mul_fp16.cpp) +add_example_executable(example_groupnorm_fwd_splitk_fp16 groupnorm_fwd_splitk_fp16.cpp) +add_example_executable(example_groupnorm_fwd_swish_fp16 groupnorm_fwd_swish_fp16.cpp) diff --git a/example/42_groupnorm_fwd/common.hpp b/example/42_groupnorm_fwd/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..038a8c0f1fea900754fa1ed27c6b5210e92072e9 --- /dev/null +++ b/example/42_groupnorm_fwd/common.hpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp" diff --git a/example/42_groupnorm_fwd/groupnorm_fwd_sigmoid_mul_fp16.cpp b/example/42_groupnorm_fwd/groupnorm_fwd_sigmoid_mul_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15c02b8213c9f70bf919ec4c6eb083d7cf1573bf --- /dev/null +++ b/example/42_groupnorm_fwd/groupnorm_fwd_sigmoid_mul_fp16.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; + +#define SAVE_MEAN_INV_STD + +struct YElementOp +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, + "Data type is not supported by this operation!"); + + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, + "Data type is not supported by this operation!"); + + X a; + + ck::tensor_operation::element_wise::Sigmoid{}(a, x); + + y = ck::type_convert(x * a); + }; +}; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationFwdImpl; // SaveMeanInvStdScalarPerVector + +#include "run_groupnorm_fwd_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_fwd_example(argc, argv); } diff --git a/example/42_groupnorm_fwd/groupnorm_fwd_splitk_fp16.cpp b/example/42_groupnorm_fwd/groupnorm_fwd_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..37fdf1b253d677c7312c09192bbbe1c521768beb --- /dev/null +++ b/example/42_groupnorm_fwd/groupnorm_fwd_splitk_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using YElementOp = ck::tensor_operation::element_wise::Swish; + +#define SAVE_MEAN_INV_STD + +using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationFwdSplitKImpl< + XDataType, + GammaDataType, + BetaDataType, + ComputeDataType, + YDataType, + SaveMeanInvStdDataType, + YElementOp, + Rank, + NumReduceDim, + 256, // BlockSize + 1, // ClusterM + 256, // ClusterK + 1, // SliceM + 16, // SliceK + 1, // SrcVecDim (0=M, 1=K) + 2, // SrcScalarPerVector + 1, // GammaVecDim (0=M, 1=K) + 2, // GammaScalarPerVector + 1, // BetaVecDim (0=M, 1=K) + 2, // BetaScalarPerVector + 2, // YScalarPerVector + 1>; // SaveMeanInvStdScalarPerVector + +#include "run_groupnorm_fwd_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_fwd_example(argc, argv); } diff --git a/example/42_groupnorm_fwd/groupnorm_fwd_swish_fp16.cpp b/example/42_groupnorm_fwd/groupnorm_fwd_swish_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6d17264cef8f1c42c02ed94b76bef22db7059b53 --- /dev/null +++ b/example/42_groupnorm_fwd/groupnorm_fwd_swish_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using YElementOp = ck::tensor_operation::element_wise::Swish; + +#define SAVE_MEAN_INV_STD + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationFwdImpl; // SaveMeanInvStdScalarPerVector + +#include "run_groupnorm_fwd_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_fwd_example(argc, argv); } diff --git a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..853ff791a6e8d0d07cf1584243e267f512291530 --- /dev/null +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +int run_groupnorm_fwd_example(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + if(argc == 1) + { + // use default case + } + else if(argc == 6) + { + N = std::stoi(argv[1]); + H = std::stoi(argv[2]); + W = std::stoi(argv[3]); + G = std::stoi(argv[4]); + C = std::stoi(argv[5]); + } + else + { + std::cerr << "arg1 to 5: N, H, W, G, C" << std::endl; + + return 1; + } + + Tensor x({N, H, W, G, C}); + Tensor y({N, H, W, G, C}); + Tensor gamma({G, C}); + Tensor beta({G, C}); + Tensor save_mean({N, G}); + Tensor save_inv_std({N, G}); + + ck::utils::FillUniformDistribution{0.f, 1.f}(x); + ck::utils::FillUniformDistribution{0.f, 1.f}(gamma); + ck::utils::FillUniformDistribution{0.f, 1.f}(beta); + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); +#ifdef SAVE_MEAN_INV_STD + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); +#endif + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + const auto y_element_op = YElementOp{}; + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + {N, H, W, G, C}, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + {0, 0, 0, C, 1}, + {0, 0, 0, C, 1}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + {1, 2, 4}, // reduction dimension: [H, W, C] + 1e-6, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + y_element_op); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + 
size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true}); + + std::size_t num_btype = sizeof(XDataType) * N * H * W * G * C + + sizeof(YDataType) * N * H * W * G * C + sizeof(GammaDataType) * G * C + + sizeof(BetaDataType) * G * C; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s, " + << device_instance.GetTypeString() << std::endl; + + bool pass = true; + { + Tensor host_y({N, H, W, G, C}); + Tensor host_save_mean(HostTensorDescriptor{N, G}); + Tensor host_save_inv_std(HostTensorDescriptor{N, G}); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnorm; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + y_element_op, + {N, H, W, G, C}, + 1e-6); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + y_dev.FromDevice(y.mData.data()); + pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.FromDevice(save_mean.mData.data()); + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err( + save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3); + pass &= ck::utils::check_err( + save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3); +#endif + } + + return (pass ? 0 : 1); +} diff --git a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt index 7e070f5357999b73ce647489c37613a4e52329f8..c29f18f1627ef20fe69cb3751d3a7766ffbd236c 100644 --- a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt +++ b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt @@ -1,6 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) -endif() +add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) +add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index 877a82031598ba63c3c331b0bf7df34a2c07d5a7..867493465d5bf82f262342c5be181312069fbc65 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -1,4 +1,8 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) - add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) -endif() +add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) +add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permute_4D_fp32_row.cpp) +add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp) 
+add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp) +add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp) +add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp) +add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp) +add_example_executable(elementwise_scale_permute_amax_2D_fp16_fp8 elementwise_scale_permute_amax_2D_fp16_fp8.cpp) diff --git a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8819bb65e6f3de262a49ee2e9aefc248a95616f2 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using BinaryAdd = ck::tensor_operation::element_wise::Add; +// B = alpha * A0 * A0 + beta * A1 * A1 +using BinaryAddUnaryScaleSquare = ck::tensor_operation::element_wise:: + BinaryWithUnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + BinaryAddUnaryScaleSquare, // ElementwiseOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8, 8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::array ab_lengths; + std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 2> as = {Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides)}; + Tensor& a0 = as[0]; + Tensor& a1 = as[1]; + Tensor b(ab_lengths, ab_strides); + float alpha = 3.f; + float beta = 2.f; + a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0.mData.data()); + a1_device_buf.ToDevice(a1.mData.data()); + + std::array inputs = {a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}; + 
std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}; + auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, + {ab_strides, ab_strides}, + {ab_strides}, + inputs, + output, + BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A0 (nchw): " << a0.mDesc << std::endl; + std::cout << "A1 (nchw): " << a1.mDesc << std::endl; + std::cout << "B (nchw): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 2ceda86839abc37bd71fc1179d12be51d42e9a20..3ea1aa4bf860bce93156b50646603c4ff1ac037b 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -1,9 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
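+// fp16 NCHW -> NHWC permute example, rewritten to use the dynamic-vector-dims
+// DeviceElementwiseImpl (256-thread blocks, 128x128 tiles, 8x8 elements per thread)
+// and verified against the CPU ReferenceElementwise operation instead of the previous
+// hand-written host loop.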
+ #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -17,28 +22,20 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, - ck::Tuple, - PassThrough, - 4, - 8, - ck::Sequence<8>, - ck::Sequence<1>>; - -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // Elementwise + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq int main() { @@ -47,18 +44,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -69,9 +54,22 @@ int main() 1, static_cast(nhwc[2] * nhwc[3]), static_cast(nhwc[3])}; - ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -99,15 +97,20 @@ int main() std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; - bool pass = true; if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + 
ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp deleted file mode 100644 index 6b94a5d46f5ad20474900365ab6397ef48d87769..0000000000000000000000000000000000000000 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp +++ /dev/null @@ -1,130 +0,0 @@ -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" - -using F16 = ck::half_t; - -using ADataType = F16; -using BDataType = F16; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwise2dImpl, - ck::Tuple, - PassThrough, - 3, // NumDim_M - 1, // NumDim_N - 8, - 8, - ck::Sequence<8>, - ck::Sequence<8>>; - -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - const std::vector& shape_nchw, - Functor functor) -{ - for(std::size_t n = 0; n < shape_nchw[0]; ++n) - for(std::size_t c = 0; c < shape_nchw[1]; ++c) - for(std::size_t h = 0; h < shape_nchw[2]; ++h) - for(std::size_t w = 0; w < shape_nchw[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - -int main() -{ - bool do_verification = true; - bool time_kernel = true; - - const int N = 120; - const int C = 128; - const int H = 32; - const int W = 1024; - - /**const int N = 120; - const int H = 32; - const int W = 64; - - const int C = 128;**/ - - std::vector nchw = {N, C, H, W}; - std::vector nhwc = {N, H, W, C}; - - Tensor a(nchw); - Tensor b(nhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - // LogRangeAsType(std::cout << "Tensor a : ", a.mData, ",") << std::endl; - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; - - std::array ab_lengths{N, H, W, C}; - - std::array a_strides = {C * H * W, W, 1, H * W}; - std::array b_strides = {H * W * C, W * C, C, 1}; - - auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); - - if(!broadcastPermute.IsSupportedArgument(argument.get())) - { - throw std::runtime_error( - "The runtime parameters seems not supported by the device instance, exiting!"); - }; - - std::cout << "A (nchw): " << a.mDesc << std::endl; - std::cout << "B (nhwc): " << b.mDesc << std::endl; - - auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); - float ave_time = - 
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; - - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - bool pass = true; - - if(do_verification) - { - b_device_buf.FromDevice(b.mData.data()); - // LogRangeAsType(std::cout << "Tensor b : ", b.mData, ",") << std::endl; - - Tensor host_b(nhwc); - host_elementwise4D, Tensor, PassThrough>( - host_b, a, nchw, PassThrough{}); - - // LogRangeAsType(std::cout << "Host b : ", host_b.mData, ",") << std::endl; - pass &= - ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); - } - - return pass ? 0 : 1; -} diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13c67fce059b04d9d4060b521181529a9dd44e5a --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryScaleSquare, // UnaryScaleSquare + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 8, 32, 64}; + std::vector nhwc = {16, 32, 64, 8}; + std::array ab_lengths; + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + float scale = 1.f; + auto i = 0; + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 1); + for(std::size_t w 
= 0; w < a.mDesc.GetLengths()[3]; ++w) + for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) + for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c) + for(std::size_t n = 0; n < a.mDesc.GetLengths()[0]; ++n) + { + a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + + (h * nchw[3]) + w] = i; + i = dis(gen); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = + (2 * sizeof(ADataType) + sizeof(BDataType)) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a0f6fec10567567f1e43418292cbc6c3e576579 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
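+// fp16 example that applies the combined square-then-scale unary op while permuting a
+// row-major NCHW input into NHWC output strides; results are checked against the CPU
+// ReferenceElementwise operation.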
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryScaleSquare, // UnaryScaleSquare + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::vector nhwc = {16, 32, 64, 128}; + + std::array ab_lengths; + std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + std::array b_strides = {static_cast(nhwc[1] * nhwc[2] * nhwc[3]), + 1, + static_cast(nhwc[2] * nhwc[3]), + static_cast(nhwc[3])}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor 
host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc664186bef8c9468b44c68394f75de48e1118b3 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F32; +using BDataType = F32; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryScaleSquare, // UnaryScaleSquare + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<1>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 8, 32, 64}; + std::vector nhwc = {16, 32, 64, 8}; + std::array ab_lengths; + + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + float scale = 1.f; + auto i = 0; + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 1); + for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w) + for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) + for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c) + for(std::size_t n = 0; n < a.mDesc.GetLengths()[0]; ++n) + { + a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + + (h * nchw[3]) + w] = i; + i = dis(gen); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem 
b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0c416318a0908b2c4bbe33e2de728da0d75acf1 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
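+// fp32 variant of the row-major permute example: the same square-then-scale unary op
+// is applied while the NCHW input is written out with NHWC strides and verified against
+// the CPU reference.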
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F32; +using BDataType = F32; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryScaleSquare, // UnaryScaleSquare + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::vector nhwc = {16, 32, 64, 128}; + + std::array ab_lengths; + std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + std::array b_strides = {static_cast(nhwc[1] * nhwc[2] * nhwc[3]), + 1, + static_cast(nhwc[2] * nhwc[3]), + static_cast(nhwc[3])}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor 
host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9431a8cde4cb62e7ca7407a8c4b2c74370877612 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/utility/reduction_enums.hpp" + +using F16 = ck::half_t; +using F32 = float; +using F8 = ck::f8_t; + +using InputDataType = F16; +using ScaleDataType = F32; +using OutputDataType = F8; + +static constexpr ck::index_t NumDim = 2; + +constexpr ck::ReduceTensorOp ReduceOpId = ck::ReduceTensorOp::MAX; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; + +using ReduceOperation = typename ck::reduce_binary_operator::opType; + +struct ScalePassThrough +{ + ScalePassThrough(const float alpha = 1.f) : alpha_(alpha) {} + + __host__ __device__ constexpr void + operator()(OutputDataType& y0, OutputDataType& y1, const InputDataType& x0) const + { + y0 = ck::type_convert(ck::type_convert(x0) * alpha_); + y1 = y0; + } + + const ScaleDataType alpha_; +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using UnaryAbs = ck::tensor_operation::element_wise::UnaryAbs; + +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + ScalePassThrough, // Elementwise + NumDim, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8, 1>>; // OutScalarPerVectorSeq + +using DeviceReduceInstance = + ck::tensor_operation::device::DeviceReduceMultiBlock; // OutDstVectorSize + +void reference_scale_permute_amax(Tensor& input, + Tensor& host_output_scaled_casted_transposed, + Tensor& host_output_scaled_casted, + Tensor& host_output_amax, + const float scale) +{ + ScalePassThrough out_element_op(scale); + const ck::index_t M = 
input.GetLengths()[0]; + const ck::index_t K = input.GetLengths()[1]; + for(ck::index_t m = 0; m < M; m++) + { + for(ck::index_t k = 0; k < K; k++) + { + OutputDataType y0, y1; + out_element_op(y0, y1, input(m, k)); + + host_output_scaled_casted(m, k) = y0; + host_output_scaled_casted_transposed(m, k) = y1; + const OutputDataType y_fabs = + ck::type_convert(ck::math::abs(ck::type_convert(y0))); + host_output_amax(0) = ck::type_convert(ck::math::max( + ck::type_convert(y_fabs), ck::type_convert(host_output_amax(0)))); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + bool time_kernel = true; + + const float scale = 2.f; + + ck::index_t M = 1024; + ck::index_t K = 1024; + + if(argc == 3) + { + M = std::stoi(argv[1]); + K = std::stoi(argv[2]); + } + + std::array dims = {M, K}; + std::array in_strides = {K, 1}; + std::array out_strides = {1, M}; + + Tensor input(dims, in_strides); + Tensor output_scaled_casted_transposed(dims, out_strides); + Tensor output_scaled_casted(dims, in_strides); + Tensor output_amax({1}); + + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem input_dev_buf(sizeof(InputDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem output_scaled_casted_transposed_dev_buf( + sizeof(OutputDataType) * output_scaled_casted_transposed.mDesc.GetElementSpaceSize()); + DeviceMem output_scaled_casted_dev_buf(sizeof(OutputDataType) * + output_scaled_casted.mDesc.GetElementSpaceSize()); + DeviceMem output_amax_dev_buf(sizeof(OutputDataType) * output_amax.mDesc.GetElementSpaceSize()); + + input_dev_buf.ToDevice(input.mData.data()); + + std::array inputs = {input_dev_buf.GetDeviceBuffer()}; + std::array outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(), + output_scaled_casted_dev_buf.GetDeviceBuffer()}; + + std::cout << "Input: " << input.mDesc << std::endl; + std::cout << "Scale: " << scale << std::endl; + std::cout << "Output scaled casted transposed: " << output_scaled_casted_transposed.mDesc + << std::endl; + std::cout << "Output scaled casted: " << output_scaled_casted.mDesc << std::endl; + std::cout << "Output amax: " << output_amax.mDesc << std::endl; + + auto launch_transpose_scale = [&]() { + auto transposeScale = DeviceElementwisePermuteInstance{}; + auto argument = transposeScale.MakeArgumentPointer(dims, + {in_strides}, + {out_strides, in_strides}, + inputs, + outputs, + ScalePassThrough{scale}); + + if(!transposeScale.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto transposeScale_invoker_ptr = transposeScale.MakeInvokerPointer(); + return transposeScale_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + }; + + auto launch_reduce = [&]() { + auto reduce = DeviceReduceInstance{}; + auto reduce_argument_ptr = + reduce.MakeArgumentPointer(dims, + in_strides, + {1}, // Output Lengths + {1}, // Output Strides + {0, 1}, // Reduce Dims + static_cast(1.f), + static_cast(0.f), + output_scaled_casted_dev_buf.GetDeviceBuffer(), + nullptr, + output_amax_dev_buf.GetDeviceBuffer(), + nullptr, + UnaryAbs{}, + PassThrough{}); + + if(!reduce.IsSupportedArgument(reduce_argument_ptr.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + return invoker_ptr->Run(reduce_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + }; + + float ave_time = 
launch_transpose_scale(); + ave_time += launch_reduce(); + std::cout << "Perf: " << ave_time << " ms" << std::endl; + bool pass = true; + + if(do_verification) + { + Tensor host_output_scaled_casted_transposed(dims, out_strides); + Tensor host_output_scaled_casted(dims, in_strides); + Tensor host_output_amax({1}); + + reference_scale_permute_amax(input, + host_output_scaled_casted_transposed, + host_output_scaled_casted, + host_output_amax, + scale); + + output_scaled_casted_transposed_dev_buf.FromDevice( + output_scaled_casted_transposed.mData.data()); + output_scaled_casted_dev_buf.FromDevice(output_scaled_casted.mData.data()); + output_amax_dev_buf.FromDevice(output_amax.mData.data()); + + pass &= ck::utils::check_err(output_scaled_casted_transposed.mData, + host_output_scaled_casted_transposed.mData, + "Error: Incorrect results scaled transposed", + 1e-3, + 1e-3); + pass &= ck::utils::check_err(output_scaled_casted.mData, + host_output_scaled_casted.mData, + "Error: Incorrect results scaled", + 1e-3, + 1e-3); + pass &= ck::utils::check_err( + output_amax.mData, host_output_amax.mData, "Error: Incorrect results amax", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..050300eed2311270871368a493faf3212e44a334 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using BinaryAdd = ck::tensor_operation::element_wise::Add; +// B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2 +using TrinaryAddUnaryScaleSquare = + ck::tensor_operation::element_wise::TrinaryWithUnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + TrinaryAddUnaryScaleSquare, // ElementwiseOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::array ab_lengths; + std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 3> as = 
{Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides)}; + Tensor& a0 = as[0]; + Tensor& a1 = as[1]; + Tensor& a2 = as[2]; + Tensor b(ab_lengths, ab_strides); + float alpha = 3.f; + float beta = 2.f; + float gamma = 4.f; + a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a2.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize()); + DeviceMem a2_device_buf(sizeof(ADataType) * a2.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0.mData.data()); + a1_device_buf.ToDevice(a1.mData.data()); + a2_device_buf.ToDevice(a2.mData.data()); + + std::array inputs = {a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + a2_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}; + auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}; + auto unary_scale_op_a2 = UnaryScaleSquare{UnarySquare{}, UnaryScale{gamma}}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, + {ab_strides, ab_strides, ab_strides}, + {ab_strides}, + inputs, + output, + TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A0 (nchw): " << a0.mDesc << std::endl; + std::cout << "A1 (nchw): " << a1.mDesc << std::endl; + std::cout << "A2 (nchw): " << a2.mDesc << std::endl; + std::cout << "B (nchw): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + ref_invoker.Run(ref_argument); + + const double threshold = std::pow(2, -10) * 2; + b_device_buf.FromDevice(b.mData.data()); + pass &= ck::utils::check_err( + b.mData, host_b.mData, "Error: Incorrect results b", threshold, threshold); + } + + return pass ? 
0 : 1; +} diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index 76361f87a5b58071b149cab505f7c0485abfadab..c02d540983ed17f14f3376ab158308147337fc11 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -167,20 +167,31 @@ int main() XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{}); Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor host_save_mean({M}); + Tensor host_save_inv_std({M}); using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, YElementwiseOperation{}, {M, N}, {1}, 1e-4); - auto ref_invoker = ref.MakeInvoker(); + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + YElementwiseOperation{}, + {M, N}, + {1}, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); ref_invoker.Run(ref_argument); y_dev.FromDevice(y.mData.data()); diff --git a/example/46_gemm_add_multiply/CMakeLists.txt b/example/46_gemm_add_multiply/CMakeLists.txt index cf7d81f895b82244a0ae7ac127e9df4c9a03324f..bfe057e8da677f30844a742b032522a2a61a947c 100644 --- a/example/46_gemm_add_multiply/CMakeLists.txt +++ b/example/46_gemm_add_multiply/CMakeLists.txt @@ -1,6 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - if(DL_KERNELS) - add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) - endif() - add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) -endif() +add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) +add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) diff --git a/example/46_gemm_add_multiply/README.md b/example/46_gemm_add_multiply/README.md index ee5cdee3659e7b6c3ba9bc083c48f91fcd66e581..e2de4696f37f5fe4fdd1a0a9a1537ae2b3899ba1 100644 --- a/example/46_gemm_add_multiply/README.md +++ b/example/46_gemm_add_multiply/README.md @@ -8,19 +8,3 @@ #arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" ./bin/example_gemm_add_multiply_dl_fp16 1 1 1 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} -d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} -d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} -arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} -arg.e_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... 
-Perf: 3.99904 ms, 32.22 TFlops, 31.9913 GB/s, DeviceGemmMultipleD_Dl<256, 128, 128, 16, 2, 4, 4, 1> -``` diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt index 14432f6e23d4adbf09e0df6c805a68f1a5571f69..df1956ca62c7968ee3b66318a82358d177c80188 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt +++ b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp similarity index 100% rename from example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp rename to example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp diff --git a/example/48_pool3d_fwd/CMakeLists.txt b/example/48_pool3d_fwd/CMakeLists.txt index f821f2c97b0e7f611d500c07ff1297f70ddd6251..492cb4d37ec8e5e851106f6611e09e5ff9d54d7a 100644 --- a/example/48_pool3d_fwd/CMakeLists.txt +++ b/example/48_pool3d_fwd/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) -endif() +add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) diff --git a/example/48_pool3d_fwd/pool3d_fwd_common.hpp b/example/48_pool3d_fwd/pool3d_fwd_common.hpp index 39032fa12378aa1ac5a00783384f130246b5dd93..788f38ec52d06acf2f4cf891cf29191d7948a39d 100644 --- a/example/48_pool3d_fwd/pool3d_fwd_common.hpp +++ b/example/48_pool3d_fwd/pool3d_fwd_common.hpp @@ -32,6 +32,8 @@ std::vector f_tensor_strides_ncdhw(ck::index_t N_, return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; else if constexpr(ck::is_same::value) return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; + throw std::runtime_error("Pool3d_fwd: problem with layout. "); + return {0, 0, 0, 0, 0}; }; template @@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, return HostTensorDescriptor({N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); } + throw std::runtime_error("Pool3d_fwd: problem with layout. 
"); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); }; template ; // InSrcOutDstVectorSize using DeviceMaxPoolBwdInstance = ck::tensor_operation::device:: - DeviceIndexPoolBwdImpl; + DeviceMaxPoolBwdImpl; const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; const ck::index_t Xs = (X - 1) * window_dilation_w + 1; @@ -155,7 +155,8 @@ bool maxpool_bwd_test(bool do_verification, dout_n_c_ho_wo.mDesc.GetElementSpaceSize(), din_n_c_hi_wi_device.mDesc.GetElementSpaceSize(), window_spatial_lengths, - window_strides); + window_strides, + window_dilations); if(!pool_bwd.IsSupportedArgument(pool_bwd_argument_ptr.get())) { diff --git a/example/50_put_element/CMakeLists.txt b/example/50_put_element/CMakeLists.txt index eca410008fde33b6236c2c14c10650a74640c53d..1b0020ebcf65a87ad39b004c149385cdd7de89ae 100644 --- a/example/50_put_element/CMakeLists.txt +++ b/example/50_put_element/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_put_element_fp16 put_element_fp16.cpp) -endif() +add_example_executable(example_put_element_fp16 put_element_fp16.cpp) diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp index 394f046b1e9d2136e42001f366d7423f1f5e8600..032828f7bcd47889502a619cb273c302afdc9bfc 100644 --- a/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp @@ -26,6 +26,8 @@ std::vector f_tensor_strides_ncdhw(ck::index_t N_, return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; else if constexpr(ck::is_same::value) return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; + throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); + return {0, 0, 0, 0, 0}; }; template @@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, return HostTensorDescriptor({N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); } + throw std::runtime_error("Avgpool3d_bwd: problem with layout. 
"); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); }; template , 1>; +// clang-format on + +bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params) +{ + const auto G = conv_params.G_; + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck::index_t NDoHoWo = + N * ck::accumulate_n( + conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const ck::index_t CZYX = + C * ck::accumulate_n( + conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto in_desc = HostTensorDescriptor({G, NDoHoWo, CZYX}); + const auto out_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array image_g_n_c_wis_strides{}; + std::array gemm_g_m_k_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(conv_params.input_spatial_lengths_, input_spatial_lengths); + copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths); + copy(conv_params.output_spatial_lengths_, output_spatial_lengths); + copy(in_desc.GetStrides(), gemm_g_m_k_strides); + copy(out_desc.GetStrides(), image_g_n_c_wis_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + Tensor in(in_desc); + Tensor out_device(out_desc); + Tensor out_host(out_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out_device.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: in.GenerateTensorValue(GeneratorTensor_2{1, 2}); break; + default: in.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + + // reset input to zero + out_device_buf.SetZero(); + + static_assert(std::is_default_constructible_v); + + // do conv + auto col2img = DeviceColToImgInstance{}; + auto invoker = col2img.MakeInvoker(); + auto argument = col2img.MakeArgument(in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + if(!col2img.IsSupportedArgument(argument)) + { + std::cerr << "wrong! 
device_col2img with the specified compilation parameters does " + "not support this col2img problem" + << std::endl; + + return false; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + std::size_t num_btype = G * NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + if(config.do_verification) + { + auto ref_column_to_image = ck::tensor_operation::host:: + ReferenceColumnToImage(); + + auto ref_invoker = ref_column_to_image.MakeInvoker(); + + auto ref_argument = ref_column_to_image.MakeArgument(in, + out_host, + conv_params.filter_spatial_lengths_, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_); + + if(!ref_column_to_image.IsSupportedArgument(&ref_argument)) + { + std::cerr << "wrong! ref_col2img with the specified compilation parameters does " + "not support this col2img problem" + << std::endl; + return false; + } + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(out_device.mData.data()); + return ck::utils::check_err(out_device.mData, out_host.mData); + } + + return true; +} + +int RunColumnToImageExample(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatial dimensions" << std::endl; + return EXIT_FAILURE; + } + + return !RunColumnToImage(config, conv_params); +} + +int main(int argc, char* argv[]) { return RunColumnToImageExample(argc, argv); } diff --git a/example/52_im2col_col2im/common.hpp b/example/52_im2col_col2im/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..61d30c4cb44abfcdbebffd88aebb27d9783f88c5 --- /dev/null +++ b/example/52_im2col_col2im/common.hpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
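+//
+// Shared helpers for the 52_im2col_col2im examples: common CK includes, the
+// ExecutionConfig struct (verification / init method / kernel timing flags),
+// a default 2D convolution problem, and command-line argument parsing.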
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp" + +template +using S = ck::Sequence; + +static inline constexpr ck::index_t NDimSpatial = 2; + +using FP32 = float; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +#define DefaultConvParams \ + ck::utils::conv::ConvParam \ + { \ + NDimSpatial, 1, 32, 1, 1, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_params) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + config = ExecutionConfig{}; + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + conv_params = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} diff --git a/example/52_im2col_col2im/image_to_column_f32.cpp b/example/52_im2col_col2im/image_to_column_f32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..140bdf8062d1ad91d02f84b4056d4b446011e632 --- /dev/null +++ b/example/52_im2col_col2im/image_to_column_f32.cpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
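+//
+// Example: image-to-column (im2col) rearrangement of an FP32 GNHWC image with
+// DeviceImageToColumnImpl. The device produces a [G, N*Ho*Wo, C*Y*X] column
+// matrix which is verified against the host ReferenceImageToColumn reference.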
+ +#include "common.hpp" + +using InDataType = FP32; +using OutDataType = FP32; + +using ImLayout = ck::tensor_layout::convolution::GNHWC; +using ImageToColumnOp = ck::conv_tensor_rearrange_op::ImageToColumn; + +// clang-format off +using DeviceImgToColInstance = ck::tensor_operation::device::DeviceImageToColumnImpl + //#####################| Num| ImLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + < NDimSpatial, ImLayout, InDataType, OutDataType, 256, 128, 128, S<16, 16>, 1>; +// clang-format on + +bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params) +{ + const auto G = conv_params.G_; + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck::index_t NDoHoWo = + N * ck::accumulate_n( + conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const ck::index_t CZYX = + C * ck::accumulate_n( + conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto in_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + const auto out_desc = HostTensorDescriptor({G, NDoHoWo, CZYX}); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array image_g_n_c_wis_strides{}; + std::array gemm_g_m_k_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(conv_params.input_spatial_lengths_, input_spatial_lengths); + copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths); + copy(conv_params.output_spatial_lengths_, output_spatial_lengths); + copy(in_desc.GetStrides(), image_g_n_c_wis_strides); + copy(out_desc.GetStrides(), gemm_g_m_k_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + Tensor in(in_desc); + Tensor out_device(out_desc); + Tensor out_host(out_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out_device.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + + // reset input to zero + out_device_buf.SetZero(); + + static_assert(std::is_default_constructible_v); + + // do conv + auto img2col = DeviceImgToColInstance{}; + auto invoker = img2col.MakeInvoker(); + auto argument = img2col.MakeArgument(in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + if(!img2col.IsSupportedArgument(argument)) + { + std::cerr << "wrong! 
device_img2col with the specified compilation parameters does " + "not support this img2col problem" + << std::endl; + + return false; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + std::size_t num_btype = G * NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + if(config.do_verification) + { + auto ref_image_to_column = ck::tensor_operation::host:: + ReferenceImageToColumn(); + + auto ref_invoker = ref_image_to_column.MakeInvoker(); + + auto ref_argument = ref_image_to_column.MakeArgument(in, + out_host, + conv_params.filter_spatial_lengths_, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_); + + if(!ref_image_to_column.IsSupportedArgument(&ref_argument)) + { + std::cerr << "wrong! ref_img2col with the specified compilation parameters does " + "not support this img2col problem" + << std::endl; + return false; + } + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device.mData, out_host.mData); + } + + return true; +} + +int RunImageToColumnExample(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatial dimensions" << std::endl; + return EXIT_FAILURE; + } + + return !RunImageToColumn(config, conv_params); +} + +int main(int argc, char* argv[]) { return RunImageToColumnExample(argc, argv); } diff --git a/example/53_layernorm2d_bwd/CMakeLists.txt b/example/53_layernorm2d_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a58b1109f7c409529486ee0ab62cacea7696843f --- /dev/null +++ b/example/53_layernorm2d_bwd/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp) diff --git a/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp b/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b0a3e72ad7c57e8842f4d68459b3918ffea6ae1 --- /dev/null +++ b/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
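+//
+// Example: LayerNorm backward on an [M, N] problem, split into two device kernels:
+//   1) DeviceNormalizationBwdDataImpl      -> dx from (dy, x, gamma, mean, inv_std)
+//   2) DeviceNormalizationBwdGammaBetaImpl -> dgamma, dbeta (reduction over M)
+// Both outputs are verified against the host ReferenceLayernormBwd implementation.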
+ +#include +#include +#include +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; +using DXDataType = float; +using ComputeDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +// Layernorm: + +// Input shape +// dy: [M, N] +// x: [M, N] +// mean: [M, 1] +// inv_std: [M, 1] + +// Output shape +// dx: [M, N] +// dgamma: [1, N] +// dbeta: [1, N] + +// dgamma = reduce_sum(dy * (x - mean) * inv_std, axis=0) +// dbeta = reduce_sum(dy, axis=0) + +// [CAUSION] +// In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is Invariant +// dimension, K is reduced dimension Hence, M in this example and +// DeviceNormalizationBwdGammaBetaImpl is different +using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< + DYDataType, + XDataType, + GammaDataType, + MeanInvStdDataType, + ComputeDataType, + DXDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 1, // MThreadSliceSize + 4, // KThreadSliceSize + true, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + true, // IsXFastestDimReduced + 4, // XSrcVectorSize + true, // IsGammaFastestDimReduced + 4, // GammaSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + true, // IsDXFastestDimReduced + 4>; // DXDstVectorSize + +using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< + DYDataType, + XDataType, + MeanInvStdDataType, + ComputeDataType, + DGammaDataType, + DBetaDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 4, // MThreadSliceSize + 1, // KThreadSliceSize + false, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + false, // IsXFastestDimReduced + 4, // XSrcVectorSize + true, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + 4, // DGammaDstVectorSize + 4>; // DBetaDstVectorSize + +int main() +{ + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 512; + + Tensor dy({M, N}); + Tensor x({M, N}); + Tensor gamma({N}); + Tensor mean({M}); + Tensor inv_std({M}); + + Tensor dgamma({N}); + Tensor dbeta({N}); + Tensor dx({M, N}); + + dy.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + mean.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + inv_std.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * 
mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + // backward x + auto x_device_instance = XDeviceInstance{}; + + auto x_argument_ptr = x_device_instance.MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto x_invoker_ptr = x_device_instance.MakeInvokerPointer(); + x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // backward gamma & beta + auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; + auto gamma_beta_argument_ptr = + gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N}, // outLengths + {1}, // dgammaStrides + {1}, // dbetaStrides + {0}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto gamma_beta_invoker_ptr = gamma_beta_device_instance.MakeInvokerPointer(); + gamma_beta_invoker_ptr->Run(gamma_beta_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_dgamma({N}); + Tensor host_dbeta({N}); + Tensor host_dx({M, N}); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, {M, N}); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + dx_dev.FromDevice(dx.mData.data()); + + pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3); + } + + return (pass ? 
0 : 1); +} diff --git a/example/54_groupnorm_bwd/CMakeLists.txt b/example/54_groupnorm_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2cb103499c212bc03455464049ff6005d8f19878 --- /dev/null +++ b/example/54_groupnorm_bwd/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp) diff --git a/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp b/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6cf1b2ff91e6e607f54848f5bbd27fb28a39e8e7 --- /dev/null +++ b/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; +using DXDataType = float; +using ComputeDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +// Grouprnorm +// kernel 1: M , K +// dy: N, H, W, G, C -> N * G, H * W * C +// x: N, H, W, G, C -> N * G, H * W * C +// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C +// mean: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1 +// rstd: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1 + +// dx: N, H, W, G, C -> N * G, H * W * C + +using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< + DYDataType, + XDataType, + GammaDataType, + MeanInvStdDataType, + ComputeDataType, + DXDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 1, // MThreadSliceSize + 4, // KThreadSliceSize + true, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + true, // IsXFastestDimReduced + 4, // XSrcVectorSize + true, // IsGammaFastestDimReduced + 4, // GammaSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + true, // IsDXFastestDimReduced + 4>; // DXDstVectorSize + +// kernel 2: M , K +// dy: N, H, W, G, C -> G * C, N * H * W +// x: N, H, W, G, C -> G * C, N * H * W +// mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1 +// rstd: N, 1, 1, G, 1 -> G * 1, N * 1 * 1 + +// dgamma: 1, 1, 1, G, C -> G * C +// dbeta: 1, 1, 1, G, C -> G * C + +// reduced axis: 0, 1, 2 + +using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< + DYDataType, + XDataType, + MeanInvStdDataType, + ComputeDataType, + DGammaDataType, + DBetaDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // ClusterInvariant + 32, // ClusterReduce + 4, // SliceInvariant + 1, // SliceReduce + false, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + false, // IsXFastestDimReduced + 4, // XSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + 4, // 
DGammaDstVectorSize + 4>; // DBetaDstVectorSize + +int main() +{ + bool time_kernel = false; + + ck::index_t N = 16; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 32; + ck::index_t C = 64; + + Tensor dy({N, H, W, G, C}); + Tensor x({N, H, W, G, C}); + Tensor gamma({G, C}); + Tensor mean({N, G}); + Tensor inv_std({N, G}); + + Tensor dgamma({G, C}); + Tensor dbeta({G, C}); + Tensor dx({N, H, W, G, C}); + + dy.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + mean.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + inv_std.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + std::vector dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; + std::vector gammaStrides = {0, 0, 0, C, 1}; + std::vector meanStrides = {G, 0, 0, 1, 0}; + std::vector invStdStrides = {G, 0, 0, 1, 0}; + std::vector dxStrides{dx.mDesc.GetStrides().begin(), dx.mDesc.GetStrides().end()}; + + // backward x + auto x_device_instance = XDeviceInstance{}; + + auto x_argument_ptr = x_device_instance.MakeArgumentPointer({N, H, W, G, C}, // lengths + dyStrides, // dyStrides + xStrides, // xStrides + gammaStrides, // gammaStrides + meanStrides, // meanStrides + invStdStrides, // invStdStrides + dxStrides, // dxStrides + {1, 2, 4}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." 
<< __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto x_invoker_ptr = x_device_instance.MakeInvokerPointer(); + x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // backward gamma & beta + + auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; + auto gamma_beta_argument_ptr = + gamma_beta_device_instance.MakeArgumentPointer({N, H, W, G, C}, // inLengths + dyStrides, // dyStrides + xStrides, // xStrides + meanStrides, // meanStrides + invStdStrides, // invStdStrides + {G, C}, // outLengths + {C, 1}, // dgammaStrides + {C, 1}, // dbetaStrides + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto gamma_beta_invoker_ptr = gamma_beta_device_instance.MakeInvokerPointer(); + gamma_beta_invoker_ptr->Run(gamma_beta_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_dgamma({G, C}); + Tensor host_dbeta({G, C}); + Tensor host_dx({N, H, W, G, C}); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnormBwd; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument( + dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, {N, H, W, G, C}); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + dx_dev.FromDevice(dx.mData.data()); + + pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3); + } + + return (pass ? 0 : 1); +} diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e49056a948ca6dca6970dcf65e0c3c49089fc88a --- /dev/null +++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt @@ -0,0 +1,7 @@ +add_custom_target(example_grouped_gemm_xdl_multi_abd) + +add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16) + +add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp) +add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..742fd5547a97698f363816b6833e1621bce92d45 --- /dev/null +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -0,0 +1,401 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
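+//
+// Example: grouped GEMM with fixed N/K per group and multiple A/B/D operands:
+// bf16 A0, int8 B0 combined with a bf16 scale tensor B1 via an elementwise
+// Multiply (a typical weight-dequantization pattern), and a bf16 bias D0 fused
+// with FastGelu in the epilogue (AddFastGelu). Kernel arguments are staged on
+// the device and each group's result is verified against the host ReferenceGemm.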
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 
8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; + +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a0_tensors; + std::vector> b_tensors; + std::vector> b0_tensors; + std::vector> b1_tensors; + std::vector> d0_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a0_tensors.reserve(group_count); + b_tensors.reserve(group_count); + b0_tensors.reserve(group_count); + b1_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + + a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b1_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ks[i], problem_size.Ns[i], 0, B1Layout{}))); + + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b0_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b0_tensors[i].mDesc.GetElementSize() + + sizeof(B1DataType) * b1_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * 
c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_2{0, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back( + std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(), + b0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); + + b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(), + b1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B1DataType)); + + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back( + {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, + std::array{b0_tensors_device[i]->GetDeviceBuffer(), + b1_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i], 0}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + b_element_op(b_tensors[i](k, n), b0_tensors[i](k, n), b1_tensors[i](k, n)); + } + } + + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + c_host_tensors[i], + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(512); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..809c1a956cce8dcfd28623d56e3be86075acfd28 --- /dev/null +++ 
b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp @@ -0,0 +1,397 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; +using Scale = ck::tensor_operation::element_wise::Scale; +using AddScale = ck::tensor_operation::element_wise::BinaryWithUnaryCombinedOp; + +using A0DataType = F16; +using A1DataType = F32; +using AsDataType = ck::Tuple; +using B0DataType = F16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using A0Layout = Row; +using A1Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = AddScale; +using BElementOp = PassThrough; + +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, 
BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ck::half_t>; + +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a0_tensors; + std::vector> a1_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> e_host_tensors; + std::vector> e_device_tensors; + + a0_tensors.reserve(group_count); + a1_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + e_host_tensors.reserve(group_count); + e_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, a1_tensors_device, b_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + a1_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + a1_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + e_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + e_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << e_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); + + 
switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 2; + constexpr ck::index_t NumBTensor = 1; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + + a1_tensors_device.emplace_back( + std::make_unique(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(), + a1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A1DataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + {1, 1}, + {problem_size.stride_Bs[i]}, + {0}, + 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer(), + a1_tensors_device[i]->GetDeviceBuffer()}, + std::array{b_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i], + problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i]}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + constexpr float scale = 1.f; + auto a_element_op = AElementOp{Add{}, Scale{scale}, Scale{scale}}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k)); + } + } + + c_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(), + e_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + e_host_tensors[i], + PassThrough{}, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(512); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a9e0d3f9ada23a9da91bbe1d7532a253630b276a --- /dev/null +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -0,0 +1,4 @@ +add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) 
+add_example_executable(example_gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp) +add_example_executable(example_gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp) +add_example_executable(example_gemm_multi_ABD_xdl_fastgelu_bf16_i8 gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp) diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f3bba922f05487356ad1b204397fd76b91cff29 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + 
b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 
0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95cf8f3674a08f7e8968730081b6739e473f3ac5 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = FastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| 
_NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem 
d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 0; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93034a8b70ca014a2f80c3f4e4fff2cd15ba0147 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -0,0 +1,363 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
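+// Summary of this example (descriptive comment, not part of the generated code): it
+// computes E = alpha * GEMM(scale * (A0 + A1), B) + beta * D entirely in FP16. The two
+// A tensors are combined on the fly by the custom AddScale element-wise op (which also
+// exposes a packed half4_t path via is_pack4_invocable), B uses PassThrough, and the
+// AlphaBetaAdd epilogue blends the accumulator with the single D tensor.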
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using ELayout = Row; + +struct AddScale +{ + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + static constexpr auto I2 = ck::Number<2>{}; + static constexpr auto I3 = ck::Number<3>{}; + + __host__ __device__ constexpr void + operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const + { + const auto a0_v_t = ck::vector_type{a0}; + const auto a1_v_t = ck::vector_type{a1}; + + auto r_v_t = ck::vector_type{}; + + r_v_t.AsType()(I0) = + scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); + r_v_t.AsType()(I1) = + scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); + r_v_t.AsType()(I2) = + scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); + r_v_t.AsType()(I3) = + scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); + + a = r_v_t.AsType()[I0]; + } + + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const + { + a = scale * (a0 + a1); + } + + // this attribute controls the copy_function applying element_wise_op with + // pack4_data + constexpr const static bool is_pack4_invocable = true; + + float scale = 1.0; +}; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ELayout, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 1, + 2, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + 
ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + 
e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{0.2}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA, StrideA}, + std::array{StrideB}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..07b9db46208451225a34eaa90b8914a271bed2e1 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
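+// Summary of this example (descriptive comment): the int8 weight matrix B0 is fed to
+// the GEMM as-is (both A and B element ops are PassThrough) and dequantization is
+// deferred to the epilogue. The per-column BF16 scale b1 and the bias d are supplied
+// as the two D tensors, so MultiplyAddFastGelu produces, as its name suggests,
+// E = FastGelu(C * scale + bias).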
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyAddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; +// 
clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr 
ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 1; + constexpr ck::index_t NumDTensor = 2; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer()}, + std::array{b1_device_buf.GetDeviceBuffer(), + d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{0, StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + +#if 0 + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } +#endif + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(0, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b8bd4cad2092b5fe67d10d62ea8db8aebb9a041 --- /dev/null +++ b/example/61_contraction_multi_ABD/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b88e2482d66518b4dcff93a6fd0cd0bcfaae2d4 --- /dev/null +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
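+// Summary of this example (descriptive comment): a tensor contraction with two M, two
+// N, and two K dimensions. The A operand is formed in place as A = A0 * A1, i.e. an
+// FP16 tensor scaled element-wise by a broadcast FP32 tensor, B is passed through,
+// and the AlphaBetaAdd epilogue computes E = alpha * C + beta * D.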
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using A0DataType = F16; +using A1DataType = F32; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; +using ComputeDataType = F16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +struct Multiply +{ + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const + { + a = ck::type_convert(ck::type_convert(a0) * a1); + } +}; + +using AElementOp = Multiply; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle< + NumDimM, + NumDimN, + NumDimK, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 1, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + float alpha = 1.0f; + float beta = 1.0f; + + // A0[M0, M1, K0, K1] + std::vector a0_ms_ks_lengths{30, 128, 32, 64}; + std::vector a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; + // A1[M1, K1] -> A1[M0, M1, K0, K1] + std::vector a1_ms_ks_lengths{30, 128, 32, 64}; + std::vector a1_ms_ks_strides{0, 64, 1, 0}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: 
verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; + std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_ms_ks.mData.data()); + a1_device_buf.ToDevice(a1_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, + std::array, 2>{a0_ms_ks_strides, a1_ms_ks_strides}, + std::array, 1>{b_ns_ks_lengths}, + std::array, 1>{b_ns_ks_strides}, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_contraction with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + { + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + if(do_verification) + { + + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + + for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1) + { + for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0) + { + for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1) + { + a_element_op(a_ms_ks(m0, m1, k0, k1), + a0_ms_ks(m0, m1, k0, k1), + a1_ms_ks(m0, m1, k0, k1)); + } + } + } + } + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eaabccdf2a773f2b298b7059a0a1811618ac88e5 --- /dev/null +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
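+// Summary of this example (descriptive comment): the FP8 variant of the contraction.
+// Both operands are built in place, A = A0 * A1 and B = B0 * B1, i.e. each FP8 tensor
+// is multiplied element-wise by an accompanying FP32 tensor before the contraction;
+// there are no D tensors, and the FP32 accumulator is written out as FP16 through a
+// PassThrough epilogue.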
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using EDataType = F16; +using ComputeDataType = F8; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct Multiply +{ + __host__ __device__ constexpr void + operator()(ck::f8_t& a, const ck::f8_t& a0, const float& a1) const + { + a = ck::type_convert(ck::type_convert(a0) * a1); + } +}; + +using AElementOp = Multiply; +using BElementOp = Multiply; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle< + NumDimM, + NumDimN, + NumDimK, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 1, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A0[M0, M1, K0, K1] + std::vector a0_ms_ks_lengths{30, 128, 32, 64}; + std::vector a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; + // A1[M1, K1] -> A1[M0, M1, K0, K1] + std::vector a1_ms_ks_lengths{30, 128, 32, 64}; + std::vector a1_ms_ks_strides{0, 64, 1, 0}; + // B0[N0, N1, K0, K1] + std::vector b0_ns_ks_lengths{32, 64, 32, 64}; + std::vector b0_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; + // B1[N0, N1, K0, K1] + std::vector b1_ns_ks_lengths{32, 64, 32, 64}; + std::vector b1_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); + Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor 
e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; + std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; + + std::cout << "b0_ns_ks: " << b0_ns_ks.mDesc << std::endl; + std::cout << "b1_ns_ks: " << b1_ns_ks.mDesc << std::endl; + + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_ms_ks.mData.data()); + a1_device_buf.ToDevice(a1_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, + std::array, 2>{a0_ms_ks_strides, a1_ms_ks_strides}, + std::array, 2>{b0_ns_ks_lengths, b1_ns_ks_lengths}, + std::array, 2>{b0_ns_ks_strides, b1_ns_ks_strides}, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + PassThrough{}); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_contraction with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + { + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + if(do_verification) + { + + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + + for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1) + { + for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0) + { + for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1) + { + a_element_op(a_ms_ks(m0, m1, k0, k1), + a0_ms_ks(m0, m1, k0, k1), + a1_ms_ks(m0, m1, k0, k1)); + } + } + } + } + + Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + + for(size_t n0 = 0; n0 < b_ns_ks.mDesc.GetLengths()[0]; ++n0) + { + for(size_t n1 = 0; n1 < b_ns_ks.mDesc.GetLengths()[1]; ++n1) + { + for(size_t k0 = 0; k0 < b_ns_ks.mDesc.GetLengths()[2]; ++k0) + { + for(size_t k1 = 0; k1 < b_ns_ks.mDesc.GetLengths()[3]; ++k1) + { + b_element_op(b_ns_ks(n0, n1, k0, k1), + b0_ns_ks(n0, n1, k0, k1), + b1_ns_ks(n0, n1, k0, k1)); + } + } + } + } + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = ref_op.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 
0 : 1; + } + + return 0; +} diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..79fafed4eb6d3843eacf5910b236c1c541470f13 --- /dev/null +++ b/example/62_convnd_activ/CMakeLists.txt @@ -0,0 +1,16 @@ +add_subdirectory(binary) +add_subdirectory(convinvscale) +add_subdirectory(convscale) +add_subdirectory(convscale_relu) +add_subdirectory(convscale_add) +add_subdirectory(convscale_reduce) +add_subdirectory(multi_AB) +add_subdirectory(unary) +add_subdirectory(dynamic_unary) + +add_custom_target(example_convnd_activ_xdl) +# ScaleAdd ScaleAdd Relu +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) diff --git a/example/62_convnd_activ/binary/CMakeLists.txt b/example/62_convnd_activ/binary/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9d90cdd244e75d1695b35c0bd367e457816a200b --- /dev/null +++ b/example/62_convnd_activ/binary/CMakeLists.txt @@ -0,0 +1,15 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_binary_xdl) + # Bilinear residual + add_example_executable(example_convnd_fwd_xdl_bilinear_residual_fp16 convnd_fwd_xdl_bilinear_residual_fp16.cpp) + add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_fwd_xdl_bilinear_residual_fp16) + add_example_executable(example_convnd_bwd_data_xdl_bilinear_residual_fp16 convnd_bwd_data_xdl_bilinear_residual_fp16.cpp) + add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_data_xdl_bilinear_residual_fp16) + add_example_executable(example_convnd_bwd_weight_xdl_bilinear_residual_fp16 convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp) + add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_weight_xdl_bilinear_residual_fp16) + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f5bddf2302ea8134d1c5dc361910171d5e44263d --- /dev/null +++ b/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
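+//
+// Example: 3-D grouped convolution backward-data (GNDHWC / GKZYXC / GNDHWK, fp16)
+// fused with a Bilinear residual post-op. Conceptually, per element,
+//     in_grad = alpha * conv_bwd_data(out_grad, wei) + beta * in_residual
+// where the residual (the original contents of the input tensor) is bound as the
+// single D operand of DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 and
+// InElementOp = Bilinear{alpha, beta} blends it into the result in the epilogue.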
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +template +using DeviceGroupedConvNDBwdDataInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< + NDimSpatial, + OutLayout, + WeiLayout, + ck::Tuple, + InLayout, + OutDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + InDataType, + OutElementOp, + WeiElementOp, + InElementOp, + ConvSpec, // ConvForwardSpecialization + true, + true, + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 2, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_BK1 + 0, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdDataInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 1; + Tensor 
out(out_g_n_k_wos_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor in_host(in_g_n_c_wis_desc); + + std::cout << "out: " << out.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "in: " << in_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + in_host.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in_host.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + // Initialize based on out_host + Tensor in_device(in_host); + + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + in_device_buf.ToDevice(in_device.mData.data()); + + std::array a_g_n_k_wos_lengths{}; + std::array a_g_n_k_wos_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_c_wis_lengths{}; + std::array e_g_n_c_wis_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // Use output as D + const std::array ds = {in_device_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + out_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + in_device_buf.GetDeviceBuffer(), + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{e_g_n_c_wis_lengths}, + std::array, NumDs>{e_g_n_c_wis_strides}, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + conv_param.GetFlops() + 3 * conv_param.GetInputByte() / sizeof(InDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { 
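+    // Host check: run the reference backward-data convolution with the same Bilinear
+    // post-op; the original input (in_host) is forwarded through d_tensors, so the
+    // CPU path fuses the residual exactly like the device kernel does.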
+ std::array, NumDs> d_tensors = {in_host}; + auto ref_conv = + ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_device.mData.data()); + + return ck::utils::check_err(in_device.mData, in_host.mData); + } + + return true; +} + +} // namespace + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/binary/convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa3edc5adc00e7eda0ab44004a8a02591dc7a1b7 --- /dev/null +++ b/example/62_convnd_activ/binary/convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::Bilinear; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +template +using DeviceGroupedConvNDBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, // InLayout + WeiLayout, // WeiLayout + OutLayout, // OutLayout + ck::Tuple, // DsLayout + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + ck::Tuple, // DsLayout + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, 
// NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdWeightInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t split_k = 1; + constexpr ck::index_t NumDs = 1; + Tensor in(in_g_n_c_wis_desc); + Tensor wei_host(wei_g_k_c_xs_desc); + Tensor out(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei_host.mDesc << std::endl; + std::cout << "out: " << out.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_host.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_host.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + // Initialize based on wei_host + Tensor wei_device(wei_host); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei_device.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + std::array b_g_n_c_wis_lengths{}; + std::array b_g_n_c_wis_strides{}; + std::array e_g_k_c_xs_lengths{}; + std::array e_g_k_c_xs_strides{}; + std::array a_g_n_k_wos_lengths{}; + std::array a_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), b_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), b_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), e_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), e_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides); + 
copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // Use weight as D + const std::array ds = {wei_device_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + ds, + b_g_n_c_wis_lengths, + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + std::array, NumDs>{e_g_k_c_xs_lengths}, + std::array, NumDs>{e_g_k_c_xs_strides}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + + DeviceMem workspace_buf(argument.GetWorkspaceSizeBytes()); + conv.SetWorkSpacePointer(&argument, workspace_buf.GetDeviceBuffer()); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + conv_param.GetFlops() + 3 * conv_param.GetOutputByte() / sizeof(WeiDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + std::array, NumDs> d_tensors = {wei_host}; + auto ref_conv = + ck::tensor_operation::host::ReferenceConvBwdWeight{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei_host, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + wei_device_buf.FromDevice(wei_device.mData.data()); + + return ck::utils::check_err(wei_device, wei_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ae1ebcb2cd64d0da533978995863e2f9ce80e6a5 --- /dev/null +++ b/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
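+//
+// Example: 3-D grouped forward convolution (GNDHWC / GKZYXC / GNDHWK, fp16) with a
+// Bilinear residual epilogue. Conceptually, per output element,
+//     out = alpha * conv(in, wei) + beta * out_prev
+// The pre-existing output tensor is bound as the single D operand of
+// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, and OutElementOp = Bilinear blends it
+// with the convolution result inside the same kernel.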
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +using OutElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 1; + Tensor 
in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + out_host.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + out_host.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + // Initialize based on out_host + Tensor out_device(out_host); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + out_device_buf.ToDevice(out_device.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // Use output as D + const std::array ds = {out_device_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{e_g_n_k_wos_lengths}, + std::array, NumDs>{e_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + conv_param.GetFlops() + 3 * conv_param.GetOutputByte() / sizeof(OutDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + 
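+    // Host check below: the reference convolution applies the same Bilinear
+    // out_element_op, consuming the copy of the original output kept in d_tensors,
+    // so device and host both compute out = alpha * conv(in, wei) + beta * out_prev.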
if(do_verification) + { + std::array, NumDs> d_tensors = {out_host}; + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/convinvscale/CMakeLists.txt b/example/62_convnd_activ/convinvscale/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..07f42075bd0a8a211f1c72cd7e60c6673469f846 --- /dev/null +++ b/example/62_convnd_activ/convinvscale/CMakeLists.txt @@ -0,0 +1,10 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convinvscale) + add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp b/example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4b2ebf84849e055d25da39759c5939387eaab27e --- /dev/null +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
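+//
+// Shared helpers for the conv + invscale examples:
+//   * print_helper_msg()      - command-line usage,
+//   * get_rtol() / get_atol() - per-data-type verification tolerances
+//                               (loose for f8_t / bf8_t, tight for float),
+//   * GetFlops()              - convolution FLOPs plus ds_size extra ops per output,
+//   * run_grouped_conv_fwd()  - allocates tensors, draws random scale_in / scale_wei /
+//                               scale_out, runs the device op and verifies against a
+//                               host reference conv followed by the same epilogue.
+// The ConvInvscale epilogue used here effectively computes, per output element,
+//     e = convert<OutDataType>(c * scale_in * scale_wei / scale_out)
+// (a sketch of the intended math, not a verbatim copy of the operator definition).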
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor 
out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = conv_param.GetInputByte() + + conv_param.GetWeightByte() + sizeof(float) + + sizeof(float) + sizeof(float) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fbdfc72063dc4210c3ee429dd3e369398dc748ed --- /dev/null +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
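+//
+// fp8 instantiation of the conv + invscale example: f8_t input/weights/output with
+// float accumulation, using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle tuned for a
+// 128x256x32 tile on a 256-thread block. The ConvInvscale epilogue rescales the
+// accumulator before the final conversion to f8_t (see convnd_fwd_convinvscale_common.hpp).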
+ +#include "convnd_fwd_convinvscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvInvscale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convinvscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc b/example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..797146060216ead74f0b33c61e7c236d9bbb5772 --- /dev/null +++ b/example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d101fd59bd9281a81f46c7b28868f3ad07626de0 --- /dev/null +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
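+//
+// Fused forward convolution with two extra operands (fp16, NDHWGC / GKZYXC / NDHWGK):
+//     y = relu( alpha1 * conv(x) + alpha2 * z + bias )
+// z is a full-size residual tensor, while bias holds only G x K values and is
+// broadcast over N and the spatial dimensions by describing the remaining dimensions
+// with length 1 and stride 0 (see broadcasted_bias_desc below).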
+ +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using BiasLayout = ck::tensor_layout::convolution::G_K; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +using OutElementOp = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& 
out_element_op) +{ + constexpr ck::index_t NumDs = 2; + const ck::index_t G = out_g_n_k_wos_desc.GetLengths()[0]; + const ck::index_t K = out_g_n_k_wos_desc.GetLengths()[2]; + + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_g_k_lengths; + std::array bias_g_k_strides; + // Fill other lenghts than G,K with 1 and strides with 0 + bias_g_k_lengths.fill(1); + bias_g_k_strides.fill(0); + bias_g_k_lengths[0] = G; + bias_g_k_lengths[2] = K; + bias_g_k_strides[0] = K; // stride to G + bias_g_k_strides[2] = 1; // stride to K + const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides); + + // y = relu ( alpha1 * conv(x) + alpha2 * z + bias ) + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + std::array, NumDs> d_tensors = {Tensor(out_g_n_k_wos_desc), + Tensor(broadcasted_bias_desc)}; + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + std::cout << "z_tensor: " << d_tensors[0].mDesc << std::endl; + std::cout << "bias_tensor: " << d_tensors[1].mDesc << std::endl; + + // Make sure that we allocated only G * K values for bias + assert(static_cast(d_tensors[1].mData.size()) == G * K); + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem z_buf(sizeof(OutDataType) * d_tensors[0].mDesc.GetElementSpaceSize()); + DeviceMem bias_buf(sizeof(OutDataType) * d_tensors[1].mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + z_buf.ToDevice(d_tensors[0].mData.data()); + bias_buf.ToDevice(d_tensors[1].mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + 
copy(conv_param.input_right_pads_, input_right_pads); + + const std::array ds = {z_buf.GetDeviceBuffer(), bias_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{ + e_g_n_k_wos_lengths, bias_g_k_lengths}, + std::array, NumDs>{ + e_g_n_k_wos_strides, bias_g_k_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops() + G * K + + conv_param.GetOutputByte() / sizeof(OutDataType); + std::size_t num_btype = conv_param.GetByte() + + G * K * sizeof(OutDataType) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f784655cc57da88f5e67eeb5dbcdf440b8b7d35f --- /dev/null +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
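+//
+// Same ScaleAddScaleAddRelu epilogue as the broadcasted-bias variant, conceptually
+//     y = relu( alpha1 * conv(x) + alpha2 * d0 + d1 )
+// but here both extra operands (d0, d1) are full output-sized GNDHWK tensors, so no
+// broadcasting descriptors are needed.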
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +using OutElementOp = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 2; + Tensor 
in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + std::array, NumDs> d_tensors = {Tensor(out_g_n_k_wos_desc), + Tensor(out_g_n_k_wos_desc)}; + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem d0_buf(sizeof(OutDataType) * d_tensors[0].mDesc.GetElementSpaceSize()); + DeviceMem d1_buf(sizeof(OutDataType) * d_tensors[1].mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + d0_buf.ToDevice(d_tensors[0].mData.data()); + d1_buf.ToDevice(d_tensors[1].mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + const std::array ds = {d0_buf.GetDeviceBuffer(), d1_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{ + e_g_n_k_wos_lengths, e_g_n_k_wos_lengths}, + std::array, NumDs>{ + e_g_n_k_wos_strides, e_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, 
StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + conv_param.GetFlops() + 2 * conv_param.GetOutputByte() / sizeof(OutDataType); + std::size_t num_btype = conv_param.GetByte() + + 2 * conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9264da24a69896f309796154084098d61d5e43ac --- /dev/null +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -0,0 +1,20 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convscale) + add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8 ) + + add_example_executable(example_convnd_fwd_xdl_convscale_bf8 convnd_fwd_xdl_convscale_bf8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8) + + add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8) + + add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..978221f8e1093485cec0e98b140d801c1ca01335 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
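+
+// Common host-side driver for the ConvScale examples: it allocates and
+// initializes the input and weight tensors, runs the fused grouped forward
+// convolution on the device, and checks the result against a CPU reference.
+// Conceptually (a rough sketch of the intended semantics, not the exact
+// ConvScale operator definition), the fused output element-wise op computes
+//
+//     out = fp8_convert(acc * scale_in * scale_wei * scale_out)
+//
+// where acc is the float convolution accumulator and the three scale
+// factors are drawn at random inside run_grouped_conv_fwd() below.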
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor 
out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = conv_param.GetInputByte() + + conv_param.GetWeightByte() + sizeof(float) + + sizeof(float) + sizeof(float) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1c8c3a57f0cc5517bd3621c883fb64870bfab4c --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
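+
+// ConvScale variant with bf8 (ck::bf8_t) activations and bf8 weights;
+// accumulation is done in float and the scaled result is written out as f8.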
+ +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = InDataType; +using BComputeDataType = AComputeDataType; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8590d0620f4d280f51ed0d4bd8672a60da435624 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
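+
+// Mixed-precision ConvScale variant: bf8 activations with f8 weights
+// (A and B use different compute types); the scaled output is stored as f8.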
+ +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::bf8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7d69ccffc10056b6b582d6c7124e3d5d5928da2 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
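+
+// ConvScale variant with f8 (ck::f8_t) activations and f8 weights;
+// accumulation is done in float and the scaled result is written out as f8.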
+ +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ab59e08a800435657517ecc6a6e1398820a39247 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
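+
+// Mixed-precision ConvScale variant: f8 activations with bf8 weights
+// (A and B use different compute types); the scaled output is stored as f8.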
+ +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc b/example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..797146060216ead74f0b33c61e7c236d9bbb5772 --- /dev/null +++ b/example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
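+
+// Shared driver for the ConvScale examples. Without arguments it runs the
+// default 2D problem below: G=1, N=128, K=256, C=192, a 3x3 filter over a
+// 71x71 input with stride 2, dilation 1 and padding 1. A typical invocation
+// (executable name taken from the CMake target; the build path is an
+// assumption) would be
+//
+//     ./example_convnd_fwd_xdl_convscale_fp8 1 2 1
+//
+// i.e. verify the result, use decimal initialization, and time the kernel.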
+ +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale_add/CMakeLists.txt b/example/62_convnd_activ/convscale_add/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..40cfd74aa4570b5fd3bad0f8ffb96824b97996e4 --- /dev/null +++ b/example/62_convnd_activ/convscale_add/CMakeLists.txt @@ -0,0 +1,11 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convscale_add) + add_example_executable(example_convnd_fwd_xdl_convscale_add_fp8 convnd_fwd_xdl_convscale_add_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_add example_convnd_fwd_xdl_convscale_add_fp8 ) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/convscale_add/convnd_fwd_convscale_add_common.hpp b/example/62_convnd_activ/convscale_add/convnd_fwd_convscale_add_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b588d710878623d18fbc5aa57b884eb9979b1f70 --- /dev/null +++ b/example/62_convnd_activ/convscale_add/convnd_fwd_convscale_add_common.hpp @@ -0,0 +1,315 @@ +// SPDX-License-Identifier: MIT +// 
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleAdd = ck::tensor_operation::element_wise::ConvScaleAdd; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(out_g_n_k_wos_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << 
"wei: " << wei.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + bias.GenerateTensorValue(GeneratorTensor_2{-3, 3}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + bias.GenerateTensorValue(GeneratorTensor_3{-3.0, 3.0}); + break; + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(DsDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d_g_n_k_wos_lengths{}; + std::array d_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), d_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::cout << std::endl; + std::cout << "scale_in: " << scale_in << std::endl; + std::cout << "scale_wei: " << scale_wei << std::endl; + std::cout << "scale_out: " << scale_out << std::endl; + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{bias_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 1>{{d_g_n_k_wos_lengths}}, + std::array, 1>{{d_g_n_k_wos_strides}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3 + 1; // 3 element-wise scale multipliers + 1 element-wise add + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = + conv_param.GetInputByte() + conv_param.GetWeightByte() + + sizeof(float) + sizeof(float) + sizeof(float) + conv_param.GetOutputByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach( + [&](auto&, auto idx) { out_element_op(out_host(idx), c(idx), bias(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp b/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3f592b2c54c437a843357794e20168e4f343f7d7 --- /dev/null +++ b/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
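+
+// ConvScaleAdd variant: f8 activations and weights plus one float D tensor
+// (a bias passed through the Ds tuple). Roughly, the fused output op is
+// expected to compute something along the lines of
+//
+//     out = fp8_convert((acc * scale_in * scale_wei + bias) * scale_out)
+//
+// which is what the host-side verification in
+// convnd_fwd_convscale_add_common.hpp reproduces via ConvScaleAdd; see the
+// element-wise operation headers for the exact definition.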
+ +#include "ck/utility/tuple.hpp" +#include "convnd_fwd_convscale_add_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScaleAdd; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_add_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale_add/run_convnd_fwd_convscale_add_example.inc b/example/62_convnd_activ/convscale_add/run_convnd_fwd_convscale_add_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..69faec90b59b6de7ea803b2cd992849028def09f --- /dev/null +++ b/example/62_convnd_activ/convscale_add/run_convnd_fwd_convscale_add_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
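+
+// Driver for the ConvScaleAdd example. It mirrors the plain ConvScale
+// driver, except that the D (bias) tensor uses the same layout and
+// descriptor as the output (GNWK / GNHWK / GNDHWK), since the bias is added
+// element-wise to the full convolution result.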
+ +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ff9020a70706ef1eb1d142211ff2052a486db266 --- /dev/null +++ b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt @@ -0,0 +1,14 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convscale_reduce) + add_example_executable(example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_relu_amax_fp8) + + add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp 
new file mode 100644 index 0000000000000000000000000000000000000000..6940c20695782e82b362f116530283e878365b34 --- /dev/null +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp @@ -0,0 +1,502 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/utility/type.hpp" + +namespace ew = ck::tensor_operation::element_wise; + +using PassThrough = ew::PassThrough; +using ConvScaleRelu = ew::UnaryCombinedOp; +using ConvScale = ew::UnaryCombinedOp; + +using UnaryScaleConvert = ew::Scale; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor 
wei(wei_g_k_c_xs_desc); + Tensor host_conv(out_g_n_k_wos_desc); + Tensor device_conv(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 11: // used for debugging + in.GenerateTensorValue(GeneratorTensor_1{1}); + wei.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem conv_device_buf(conv_param.GetOutputByte()); + DeviceMem out_device_buf(conv_param.GetOutputByte()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::cout << std::endl; + std::cout << "scale_in: " << scale_in << std::endl; + std::cout << "scale_wei: " << scale_wei << std::endl; + std::cout << "scale_out: " << scale_out << std::endl; + + // convolution elementwise operation + auto conv_element_op = ConvElementOp{ew::Scale{scale_in}, ew::Scale{scale_wei}, {}}; + auto scale_convert = UnaryScaleConvert{scale_out}; // elementwise scale and type cast + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto conv_invoker = conv.MakeInvoker(); + auto conv_argument = + conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + conv_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + conv_element_op); + + if(!conv.IsSupportedArgument(conv_argument)) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + std::string kernels = conv.GetTypeString(); + + float avg_time = conv_invoker.Run(conv_argument, StreamConfig{nullptr, time_kernel}); + + using DeviceElementwiseScale = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryScaleConvert, // UnaryScaleConvert + NDimSpatial + 3, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + + auto device_ew_scale = DeviceElementwiseScale{}; + auto scale_invoker = device_ew_scale.MakeInvoker(); + auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths, + {e_g_n_k_wos_strides}, + {e_g_n_k_wos_strides}, + {conv_device_buf.GetDeviceBuffer()}, + {out_device_buf.GetDeviceBuffer()}, + scale_convert); + + if(!device_ew_scale.IsSupportedArgument(scale_argument)) + { + throw std::runtime_error( + "wrong! DeviceElementwiseScale with the specified compilation parameters does " + "not support this problem"); + } + + kernels += std::string("\n\t\t ") + device_ew_scale.GetTypeString(); + + avg_time += scale_invoker.Run(scale_argument, StreamConfig{nullptr, time_kernel}); + + constexpr auto ReduceOpId = ck::ReduceTensorOp::AMAX; + using ReduceOperation = typename ck::reduce_binary_operator::opType; + using InElementwiseOperation = + typename ck::reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename ck::reduce_unary_operator::AccElementwiseOperation; + using DeviceReduceInstance = + ck::tensor_operation::device::DeviceReduceMultiBlock; // OutDstVectorSize + + std::vector outLengths = {1}; + Tensor amax_host(outLengths); + Tensor amax_from_device(outLengths); + auto amax_host_strides = amax_host.mDesc.GetStrides(); + + std::array reduce_dims; + std::iota(reduce_dims.begin(), reduce_dims.end(), 0); // 0,..., NDimSpatial+3-1 + + std::array reduce_out_lengths{1}; + std::array reduce_out_strides{static_cast(amax_host_strides[0])}; + + DeviceMem amax_device(sizeof(ConvOutDataType) * amax_host.mDesc.GetElementSpaceSize()); + DeviceMem index_device; + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + std::tie(in_elementwise_op, acc_elementwise_op) = + ck::reduce_unary_operator::GetElementwiseOperator( + static_cast(host_conv.mDesc.GetElementSize())); + + // Hack convolution output strides for reduction as kernel expects stride 1 for the last + // dimension. It only works because the reduction is done on the whole tensor and result is + // independent of the order of elements. + std::array reduction_strides{}; + copy(HostTensorDescriptor(e_g_n_k_wos_lengths).GetStrides(), reduction_strides); + + auto device_reduce = DeviceReduceInstance{}; + auto reduce_invoker = device_reduce.MakeInvokerPointer(); + auto reduce_argument = device_reduce.MakeArgumentPointer(e_g_n_k_wos_lengths, + reduction_strides, + reduce_out_lengths, + reduce_out_strides, + reduce_dims, + 1.0, + 0.0, + conv_device_buf.GetDeviceBuffer(), + nullptr, + amax_device.GetDeviceBuffer(), + nullptr, + in_elementwise_op, + acc_elementwise_op); + + if(!device_reduce.IsSupportedArgument(reduce_argument.get())) + { + throw std::runtime_error( + "wrong! 
DeviceReduceInstance with the specified compilation parameters does " + "not support this runtime parameters!"); + }; + + kernels += std::string("\n\t\t ") + device_reduce.GetTypeString(); + + float reduce_time = + reduce_invoker->Run(reduce_argument.get(), StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + std::cout << "\nReduce time: " << reduce_time << " ms" << std::endl; + + avg_time += reduce_time; + + std::size_t flop = conv_param.GetFlops(); // convolution FLOPs + auto conv_out_elems = host_conv.GetElementSize(); // number of elements in conv result tensor + + // 3 element-wise scale multipliers + 1 AMAX + std::size_t elementwise_ops = 3 + 1; + if constexpr(ck::is_same_v) + { + elementwise_ops += 1; // +1 element-wise relu + } + + flop += elementwise_ops * conv_out_elems; + + // convolution + elementwise scaling (in + wei + output byte count) + std::size_t num_btype = conv_param.GetByte(); + num_btype += sizeof(float) + sizeof(float); // + 2 scales + + // elementwise scaling + F8 conversion + num_btype += conv_param.GetOutputByte() + sizeof(float) + + conv_param.GetOutputByte(); + + // AMAX + num_btype += conv_param.GetOutputByte() + sizeof(float); + + if(time_kernel) + { + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << std::endl; + } + + std::cout << "\nKernels: " << kernels << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + host_conv, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + conv_element_op); + + ref_invoker.Run(ref_argument); + + conv_device_buf.FromDevice(device_conv.mData.data()); + + out_device_buf.FromDevice(out_device.mData.data()); + + out_host.ForEach([&](auto&, auto idx) { scale_convert(out_host(idx), host_conv(idx)); }); + + std::cout << "\nComparing output to reference: " << std::endl; + auto tight_tol_check = ck::utils::check_err(out_device, out_host, "Error: "); + if(!tight_tol_check) + { + std::cout << "\n\tRecompare applying tolerances...\n"; + std::cout << "\t\trtol = " << get_rtol() << std::endl; + std::cout << "\t\tatol = " << get_atol() << std::endl; + auto loose_tol_check = ck::utils::check_err(out_device, + out_host, + "Error: incorrect convolution results!", + get_rtol(), + get_atol()); + if(!loose_tol_check) + { + return false; + } + } + std::cout << "Success!" << std::endl; + + /// Verify AMAX + + using RefReduceInstance = + ck::tensor_operation::host::ReferenceReduce; + + auto ref_reduce = RefReduceInstance{}; + auto ref_reduce_invoker = ref_reduce.MakeInvokerPointer(); + auto ref_reduce_argument = ref_reduce.MakeArgumentPointer(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + reduce_out_lengths, + reduce_out_strides, + reduce_dims, + 1.0, + 0.0, + host_conv.mData.data(), + nullptr, + amax_host.mData.data(), + nullptr, + in_elementwise_op, + acc_elementwise_op); + + if(!ref_reduce.IsSupportedArgument(ref_reduce_argument.get())) + { + throw std::runtime_error( + "wrong! 
RefReduceInstance with the specified compilation parameters does " + "not support this runtime parameters!"); + }; + + ref_reduce_invoker->Run(ref_reduce_argument.get()); + + amax_device.FromDevice(amax_from_device.mData.data()); + + std::cout << "\namax: " << amax_from_device.mData[0] << std::endl; + std::cout << "amax_ref: " << amax_host.mData[0] << std::endl; + + return ck::utils::check_err(amax_from_device, amax_host, "Error: incorrect AMAX results!"); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a8b4fdbeadc0b90adc2b518309c16c11acde16ef --- /dev/null +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_convscale_reduce_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using ConvOutDataType = float; // data type of convolution result +using OutDataType = ck::f8_t; // data type of final result +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + ConvOutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..df6bf7bd5c109b68f2463d0ab70edc703a92f171 --- /dev/null +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_convscale_reduce_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using ConvOutDataType = float; // data type of convolution result +using OutDataType = ck::f8_t; // data type of final result +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScaleRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + ConvOutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale_reduce/run_convnd_fwd_example.inc b/example/62_convnd_activ/convscale_reduce/run_convnd_fwd_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..24775f21b53e9bf529a804c973bd0c7d06f5b775 --- /dev/null +++ b/example/62_convnd_activ/convscale_reduce/run_convnd_fwd_example.inc @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
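The two AMAX examples above reduce the scaled convolution output to a single absolute-maximum value and check it against a host reference. As a mental model only (the real reference path goes through ck::tensor_operation::host::ReferenceReduce, configured in the common header), AMAX of a tensor is just the largest absolute value over all of its elements; the hypothetical helper below is not part of the example:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Sketch of the quantity the examples call "amax": max_i |x_i| over a flat buffer.
float host_amax(const std::vector<float>& x)
{
    float amax = 0.f;
    for(float v : x)
        amax = std::max(amax, std::fabs(v));
    return amax;
}

int main()
{
    std::vector<float> t{-3.5f, 2.f, 0.25f};
    std::printf("amax = %f\n", host_amax(t)); // prints 3.5
}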
+ +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd< + ndim_spatial_value, + InDataType, + WeiDataType, + ConvOutDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceGroupedConvNDFwdInstance>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale_relu/CMakeLists.txt b/example/62_convnd_activ/convscale_relu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..95589cedcb406d55b9e91de6d7cdb1a06a87a01d --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/CMakeLists.txt @@ -0,0 +1,11 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convscale_relu) + add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8 ) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d2dacc20524d9b181251dfb1143108e7edba44bd --- /dev/null +++ 
b/example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + 
Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::cout << std::endl; + std::cout << "scale_in: " << scale_in << std::endl; + std::cout << "scale_wei: " << scale_wei << std::endl; + std::cout << "scale_out: " << scale_out << std::endl; + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3 + 1; // 3 element-wise scale multipliers + 1 element-wise relu + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = conv_param.GetInputByte() + + conv_param.GetWeightByte() + sizeof(float) + + sizeof(float) + sizeof(float) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..360349e7eca7526bd9e1af814213c540e116d51c --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
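A quick note on the performance arithmetic used in the common header above: invoker.Run reports the average kernel time in milliseconds, so dividing the FLOP count by 1e9 and then by the time in ms yields TFLOP/s, and dividing the byte count by 1e6 and then by the time in ms yields GB/s. A minimal unit-check sketch (the helper names are hypothetical, not part of the example):

#include <cstddef>
#include <cstdio>

// flop / 1e9 = GFLOP; GFLOP per millisecond = TFLOP per second.
inline float to_tflops(std::size_t flop, float time_ms) { return static_cast<float>(flop) / 1.e9f / time_ms; }

// bytes / 1e6 = MB; MB per millisecond = GB per second.
inline float to_gb_per_sec(std::size_t bytes, float time_ms) { return static_cast<float>(bytes) / 1.e6f / time_ms; }

int main()
{
    // e.g. 4.0e12 FLOP in 8 ms -> 500 TFLOP/s; 1.6e9 bytes in 8 ms -> 200 GB/s.
    std::printf("%.1f TFLOP/s, %.1f GB/s\n",
                to_tflops(4000000000000ull, 8.f),
                to_gb_per_sec(1600000000ull, 8.f));
}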
+ +#include "convnd_fwd_convscale_relu_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScaleRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_relu_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc b/example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..797146060216ead74f0b33c61e7c236d9bbb5772 --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..23f07439a56f21090204248036f95f31de73d6a0 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt @@ -0,0 +1,45 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_dynamic_unary_xdl) + # Sigmoid + add_example_executable(example_convnd_fwd_xdl_dynamic_sigmoid_fp16 convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_sigmoid_fp16) + # Tanh + add_example_executable(example_convnd_fwd_xdl_dynamic_tanh_fp16 convnd_fwd_xdl_dynamic_tanh_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_tanh_fp16) + # Relu + add_example_executable(example_convnd_fwd_xdl_dynamic_relu_fp16 convnd_fwd_xdl_dynamic_relu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_relu_fp16) + # 
SoftRelu + add_example_executable(example_convnd_fwd_xdl_dynamic_softrelu_fp16 convnd_fwd_xdl_dynamic_softrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_softrelu_fp16) + # Abs + add_example_executable(example_convnd_fwd_xdl_dynamic_abs_fp16 convnd_fwd_xdl_dynamic_abs_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_abs_fp16) + # Pow + add_example_executable(example_convnd_fwd_xdl_dynamic_pow_fp16 convnd_fwd_xdl_dynamic_pow_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_pow_fp16) + # Clipped Relu + add_example_executable(example_convnd_fwd_xdl_dynamic_clippedrelu_fp16 convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_clippedrelu_fp16) + # Leaky Relu + add_example_executable(example_convnd_fwd_xdl_dynamic_leakyrelu_fp16 convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_leakyrelu_fp16) + # Elu + add_example_executable(example_convnd_fwd_xdl_dynamic_elu_fp16 convnd_fwd_xdl_dynamic_elu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_elu_fp16) + # Swish + add_example_executable(example_convnd_fwd_xdl_dynamic_swish_fp16 convnd_fwd_xdl_dynamic_swish_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_swish_fp16) + # PassThrough + add_example_executable(example_convnd_fwd_xdl_dynamic_passthrough_fp16 convnd_fwd_xdl_dynamic_passthrough_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_passthrough_fp16) + # Logistic + add_example_executable(example_convnd_fwd_xdl_dynamic_logistic_fp16 convnd_fwd_xdl_dynamic_logistic_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed31be19eeaaf1096a50334a60afd0310c7c0f17 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
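The "dynamic" variants below share one compiled device instance whose output element-wise operation is ck::tensor_operation::element_wise::DynamicUnaryOp; each example then hands it a concrete activation at runtime instead of baking the activation into the template parameters. The sketch below illustrates the general idea of a runtime-selected unary epilogue with a plain std::function; it is only an analogy, not how DynamicUnaryOp is actually implemented on device:

#include <cmath>
#include <cstdio>
#include <functional>

// A unary post-op chosen at runtime; one "kernel" body serves every activation.
using UnaryOp = std::function<float(float)>;

float apply_epilogue(float conv_result, const UnaryOp& op) { return op(conv_result); }

int main()
{
    UnaryOp relu    = [](float x) { return x > 0.f ? x : 0.f; };
    UnaryOp sigmoid = [](float x) { return 1.f / (1.f + std::exp(-x)); };

    std::printf("relu(-2) = %f, sigmoid(0) = %f\n",
                apply_epilogue(-2.f, relu),     // 0
                apply_epilogue(0.f, sigmoid));  // 0.5
}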
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using DynamicElementOp = ck::tensor_operation::element_wise::DynamicUnaryOp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGroupedConvNDActivInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + DynamicElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; 
+ std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return 
true; +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_abs_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_abs_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8fa455c62e9eb6d11fb28015ff0e74d1bafaa81e --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_abs_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::UnaryAbs out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..239a21525b7efa03813af627aef98d6c0aee2877 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::ClippedRelu out_element_op(0.f, 1.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_elu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_elu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..23a094af7078dfaed23f55b0a69369304274a40b --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_elu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Elu out_element_op(2.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe4b80a68170ef832dc5fd5d590f07f086b980bb --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
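Each of these small translation units differs only in the activation object it constructs; the constructor arguments are the activation parameters. Under the usual textbook definitions (the authoritative definitions live in CK's unary element-wise operation headers), the ops used here behave as sketched below, so ClippedRelu(0.f, 1.f) clamps to [0, 1] and Elu(2.f) uses alpha = 2:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Assumed (standard) definitions -- shown only to make the constructor arguments concrete.
float clipped_relu(float x, float lo, float hi) { return std::min(std::max(x, lo), hi); }
float elu(float x, float alpha)                 { return x > 0.f ? x : alpha * (std::exp(x) - 1.f); }
float leaky_relu(float x, float alpha)          { return x > 0.f ? x : alpha * x; }

int main()
{
    std::printf("clipped_relu(1.7, 0, 1) = %f\n", clipped_relu(1.7f, 0.f, 1.f)); // 1.0
    std::printf("elu(-1, 2)              = %f\n", elu(-1.f, 2.f));               // ~-1.264
    std::printf("leaky_relu(-4, 0)       = %f\n", leaky_relu(-4.f, 0.f));        // 0 (alpha = 0 here)
}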
+ +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::LeakyRelu out_element_op(0.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_logistic_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_logistic_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..756c07ed8535fbd669fe2ebe51d5f0a7b1e6d5d6 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_logistic_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Logistic out_element_op(1.0f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_passthrough_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_passthrough_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6588ec50446ffae424f2f1b6a0794e740d28c6d1 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_passthrough_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::PassThrough out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_pow_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_pow_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..90f00a166aaae3dc9ebf203cad6ee995df679d2d --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_pow_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Power out_element_op(4.f, 1.f, 2.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_relu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_relu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..830297cb56da3f5e29c2b7d821ef84c8a9381a7b --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_relu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Relu out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b143b4a4ebf710baa54829560b545f2147b0206e --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Sigmoid out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_softrelu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_softrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83ba0f7f8c4eaf91bda03f2fa666f89a98c5fb45 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_softrelu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::SoftRelu out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_swish_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_swish_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e862d1120a8f425fcc6cdc1acc492b1b3144d72f --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_swish_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Swish out_element_op(1.0f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_tanh_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_tanh_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a91fc7ce3031a094ee14dc664302d2fa83fcf149 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_tanh_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::TanH out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/multi_AB/CMakeLists.txt b/example/62_convnd_activ/multi_AB/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c89c82d384c68a9676025421d04d542e28725710 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/CMakeLists.txt @@ -0,0 +1,17 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_multi_ab_xdl) + # ScaleAdd on A and B + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 conv_fwd_xdl_scaleadd_ab_fp16.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp16) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 conv_fwd_xdl_scaleadd_ab_fp32.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp32) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 conv_fwd_xdl_scaleadd_ab_bf16.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_bf16) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 conv_fwd_xdl_scaleadd_ab_int8.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_int8) + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b7ceee03b829a943a657a1c2090f1a9fc6bd3c44 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = ck::bhalf_t; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..08d8a89669cc85b79882d5f8a5339eba6f4c76ad --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = ck::half_t; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bef9980b3e3b1604efec00d1dcef50622b9dee73 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = float; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b132b91213e9eb0237a21b63ba9e80f49e3c400 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = int8_t; +using AccDataType = int32_t; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp b/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2626843ed4b933d66c27213c324e90fdd75a10f6 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDMultiABFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataTypes, + WeiDataTypes, + AccDataType, + DataType, + ck::Tuple<>, + DataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +namespace { +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumAs = 2; + constexpr ck::index_t NumBs = 2; + Tensor in(in_g_n_c_wis_desc); + Tensor in_bias(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor wei_bias(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + 
in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei_bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + in_bias.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + wei_bias.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem in_bias_device_buf(sizeof(InDataType) * in_bias.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem wei_bias_device_buf(sizeof(WeiDataType) * wei_bias.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + in_bias_device_buf.ToDevice(in_bias.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + wei_bias_device_buf.ToDevice(wei_bias.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + std::array as{in_device_buf.GetDeviceBuffer(), + in_bias_device_buf.GetDeviceBuffer()}; + std::array bs{wei_device_buf.GetDeviceBuffer(), + wei_bias_device_buf.GetDeviceBuffer()}; + std::array ds{}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(as, + bs, + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops() + + 2 * conv_param.GetOutputByte() / sizeof(InDataType) + + 2 * conv_param.GetOutputByte() / sizeof(WeiDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetInputByte() + + conv_param.GetWeightByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + const std::array, NumAs - 1> elementwise_a_tensors = {in_bias}; + const std::array, NumBs - 1> elementwise_b_tensors = {wei_bias}; + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + elementwise_a_tensors, + elementwise_b_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace diff --git a/example/62_convnd_activ/run_convnd_activ_dynamic_example.inc b/example/62_convnd_activ/run_convnd_activ_dynamic_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..4e90cf936653e662cb73901dcbc37a0d6ec20ef2 --- /dev/null +++ b/example/62_convnd_activ/run_convnd_activ_dynamic_example.inc @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +bool run_convnd_example(int argc, char* argv[], const OutElementOp& out_element_op) +{ + print_helper_msg(); + + bool do_verification = true; + // Use floats for SoftRelu by default to avoid overflow after e^x. + int init_method = + std::is_same_v ? 2 : 1; + bool time_kernel = false; + + // Following shapes are selected to avoid overflow. Expect inf in case of + // size increase for some elementwise ops. 
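The comments above hint at why SoftRelu defaults to the decimal initializer: SoftRelu is the softplus function ln(1 + e^x), and e^x overflows very quickly once the accumulated convolution outputs grow (the fp16 output type saturates to inf even sooner). A quick illustration of the naive form blowing up versus the usual numerically stable rewrite; this is a standalone sketch, not code from the example:

#include <algorithm>
#include <cmath>
#include <cstdio>

float softplus_naive(float x)  { return std::log(1.f + std::exp(x)); }
// Stable form: max(x, 0) + log1p(exp(-|x|)) never exponentiates a large positive value.
float softplus_stable(float x) { return std::max(x, 0.f) + std::log1p(std::exp(-std::fabs(x))); }

int main()
{
    std::printf("naive(100)  = %f\n", softplus_naive(100.f));  // inf: exp(100) overflows float
    std::printf("stable(100) = %f\n", softplus_stable(100.f)); // ~100
    std::printf("naive(2)    = %f\n", softplus_naive(2.f));    // ~2.127 (fine for small inputs)
}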
+ ck::utils::conv::ConvParam conv_param{ + 3, 2, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = [&]() { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + }; + + if(conv_param.num_dim_spatial_ == 3) + { + return run(); + } + + return false; +} diff --git a/example/62_convnd_activ/run_convnd_activ_example.inc b/example/62_convnd_activ/run_convnd_activ_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..5a402e41cdeb1ed8a7655b2a0580391ed5cc9226 --- /dev/null +++ b/example/62_convnd_activ/run_convnd_activ_example.inc @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +bool run_convnd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + // Use floats for SoftRelu by default to avoid overflow after e^x. + int init_method = + std::is_same_v ? 2 : 1; + bool time_kernel = false; + + // Following shapes are selected to avoid overflow. Expect inf in case of + // size increase for some elementwise ops. 
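For the default problem below (3 spatial dims, 17^3 input, 3^3 filter, stride 2, dilation 1, padding 1 on both sides; this reading assumes the usual ConvParam field order of ndim, groups, N, K, C, filter lengths, input lengths, strides, dilations, left pads, right pads), each output spatial length follows the standard convolution size formula and works out to 9:

#include <cstdio>

// out = (in + pad_left + pad_right - dilation * (filter - 1) - 1) / stride + 1
int conv_out_len(int in, int filter, int stride, int dilation, int pad_l, int pad_r)
{
    return (in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
}

int main()
{
    std::printf("%d\n", conv_out_len(17, 3, 2, 1, 1, 1)); // 9, for the default shape below
    std::printf("%d\n", conv_out_len(71, 3, 2, 1, 1, 1)); // 36, for the 2D convscale examples earlier
}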
+ ck::utils::conv::ConvParam conv_param{ + 3, 2, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto run = [&]() { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + }; + + if(conv_param.num_dim_spatial_ == 3) + { + return run(); + } + + return false; +} diff --git a/example/62_convnd_activ/unary/CMakeLists.txt b/example/62_convnd_activ/unary/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3470e9b9456f7361e1a05c07e99b00fe90990a31 --- /dev/null +++ b/example/62_convnd_activ/unary/CMakeLists.txt @@ -0,0 +1,45 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_unary_xdl) + # Sigmoid + add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_sigmoid_fp16) + # Tanh + add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_tanh_fp16) + # Relu + add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_relu_fp16) + # SoftRelu + add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_softrelu_fp16) + # Abs + add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_abs_fp16) + # Pow + add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_pow_fp16) + # Clipped Relu + add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_clippedrelu_fp16) + # Leaky Relu + add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_leakyrelu_fp16) + # Elu + 
add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16) + # Swish + add_example_executable(example_convnd_fwd_xdl_swish_fp16 convnd_fwd_xdl_swish_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_swish_fp16) + # PassThrough + add_example_executable(example_convnd_fwd_xdl_passthrough_fp16 convnd_fwd_xdl_passthrough_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_passthrough_fp16) + # Logistic + add_example_executable(example_convnd_fwd_xdl_logistic_fp16 convnd_fwd_xdl_logistic_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_logistic_fp16) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp b/example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4669465bf440077577202a54807e32763bc04980 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // 
ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 
in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e621c3b15e3178969ef2fc3fb44203c1efb6032f --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb26fb292cd8282951864b62c46846a9a0c7e545 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1aa4d5d4fa224b5857c19491317bce14d24d6dea --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
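// Every example in this directory follows the same three-line recipe: choose an
// elementwise OutElementOp, alias the shared device instance with it, and pull in
// the common run_convnd_activ_example.inc driver. A rough sketch of adding another
// activation the same way (kept commented out; "HardSigmoid" is a hypothetical op,
// not one shipped with CK, and the alias is assumed to be parameterized on the
// output elementwise op as in the surrounding examples):
//
//   struct HardSigmoid
//   {
//       template <typename Y, typename X>
//       __host__ __device__ void operator()(Y& y, const X& x) const
//       {
//           // clamp(0.2 * x + 0.5, 0, 1), computed in float then converted back
//           const float t = ck::type_convert<float>(x) * 0.2f + 0.5f;
//           y = ck::type_convert<Y>(t < 0.f ? 0.f : (t > 1.f ? 1.f : t));
//       }
//   };
//
//   using OutElementOp                     = HardSigmoid;
//   using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
//   #include "../run_convnd_activ_example.inc"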
+ +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Elu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..659c36ec8ddee15844b4fd2bf987c55eb725e2ca --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..86811c2e969bf47c68b024bcea6f504a4b335e7f --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Logistic; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7167c4a84a5c9eae3084fef0c5af97becbe67348 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5efa0f8f9c76ac6062a699166fe7c2d8c104e192 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
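// Note on parameterized activations (Power here, plus ClippedRelu, LeakyRelu,
// Swish, and Elu elsewhere in this directory): the shared driver constructs the
// output op as OutElementOp{}, so these examples run with whatever default
// coefficients the corresponding element_wise operator defines; to use other
// coefficients, pass a configured instance through the dynamic driver variant,
// which takes the output op as an argument.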
+ +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Power; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84b7c598eecac8e4ddf3964172921c2822f07395 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Relu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53e06d387f93127448df44dab9e4c8090805e434 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Sigmoid; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a2d76da4e214f44fc20da929a0ceea52cb6ceba3 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::SoftRelu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65a2a5023eb26a542a75e4e4c0dde83824e139a3 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Swish; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d60c005e0937cc9120cfddd67c52cbfa39ab1898 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::TanH; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/63_layernorm4d_fwd/CMakeLists.txt b/example/63_layernorm4d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3f8c679ab82c8edaa3a96f2dcf5afded77fabe8d --- /dev/null +++ b/example/63_layernorm4d_fwd/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_layernorm4d_fwd_fp16 layernorm4d_fwd_fp16.cpp) +add_example_executable(example_layernorm4d_fwd_splitk_fp16 layernorm4d_fwd_splitk_fp16.cpp) diff --git a/example/63_layernorm4d_fwd/common.hpp b/example/63_layernorm4d_fwd/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6c7b99f89f9930cdcb84bda674f96ab3a1237241 --- /dev/null +++ b/example/63_layernorm4d_fwd/common.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" diff --git a/example/63_layernorm4d_fwd/layernorm4d_fwd_fp16.cpp b/example/63_layernorm4d_fwd/layernorm4d_fwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..659cc3554d2a7f2c28f3d55bbb8a6ac65fe1c173 --- /dev/null +++ b/example/63_layernorm4d_fwd/layernorm4d_fwd_fp16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
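// When SAVE_MEAN_INV_STD is defined (as it is below), the example also asks the
// device op to write out the per-instance mean and inverse standard deviation
// alongside y; these are the quantities a training framework would typically cache
// for the backward pass. Without the macro, the driver passes nullptr for both
// outputs instead.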
+ +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 4; +constexpr int NumReduceDim = 3; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationFwdImpl; // SaveMeanInvStdScalarPerVector +#include "run_layernorm4d_fwd_example.inc" + +int main() { return run_layernorm4d_fwd_example(); } diff --git a/example/63_layernorm4d_fwd/layernorm4d_fwd_splitk_fp16.cpp b/example/63_layernorm4d_fwd/layernorm4d_fwd_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..415635a0e32d2c0777cbe87d0490863fd4a80e89 --- /dev/null +++ b/example/63_layernorm4d_fwd/layernorm4d_fwd_splitk_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 4; +constexpr int NumReduceDim = 3; + +using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationFwdSplitKImpl< + XDataType, + GammaDataType, + BetaDataType, + ComputeDataType, + YDataType, + SaveMeanInvStdDataType, + PassThrough, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // ClusterM + 32, // ClusterK + 1, // SliceM + 8, // SliceK + 1, // XYVectorDim (0=M, 1=K) + 8, // XScalarPerVector + 1, // GammaVecDim (0=M, 1=K) + 8, // GammaScalarPerVector + 1, // BetaVecDim (0=M, 1=K) + 8, // BetaScalarPerVector + 8, // YScalarPerVector + 1>; // SaveMeanInvStdScalarPerVector + +#include "run_layernorm4d_fwd_example.inc" + +int main() { return run_layernorm4d_fwd_example(); } diff --git a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..1a0b558e2ca02ec61be7c3988b57e3dd93fb3040 --- /dev/null +++ b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
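// What the driver below computes: for an NHWC input x of shape {N, H, W, C} it
// normalizes each instance n over the reduce dimensions {1, 2, 3} (= H, W, C):
//
//   mean[n]    = average over (h,w,c) of x[n,h,w,c]
//   var[n]     = average over (h,w,c) of (x[n,h,w,c] - mean[n])^2
//   y[n,h,w,c] = gamma[h,w,c] * (x[n,h,w,c] - mean[n]) / sqrt(var[n] + eps) + beta[h,w,c]
//
// with eps = 1e-4, gamma/beta broadcast from shape {H, W, C}, and save_mean /
// save_inv_std holding mean[n] and 1 / sqrt(var[n] + eps) when SAVE_MEAN_INV_STD
// is enabled.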
+ +#pragma once + +template +int run_layernorm4d_fwd_example() +{ + bool time_kernel = false; + + ck::index_t N = 256; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t C = 8; + + Tensor x({N, H, W, C}); + Tensor gamma({H, W, C}); + Tensor beta({H, W, C}); + Tensor y({N, H, W, C}); + Tensor save_mean({N}); + Tensor save_inv_std({N}); + + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); +#ifdef SAVE_MEAN_INV_STD + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); +#endif + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + {N, H, W, C}, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + {0, W * C, C, 1}, + {0, W * C, C, 1}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + {1, 2, 3}, + 1e-4, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_y({N, H, W, C}); + Tensor host_save_mean({N}); + Tensor host_save_inv_std({N}); + + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernorm; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + PassThrough{}, + {N, H, W, C}, + {1, 2, 3}, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + y_dev.FromDevice(y.mData.data()); + pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3); +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.FromDevice(save_mean.mData.data()); + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err( + save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3); + pass &= ck::utils::check_err( + save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3); +#endif + } + + return (pass ? 
0 : 1); +} diff --git a/example/64_fpAintB_gemm/CMakeLists.txt b/example/64_fpAintB_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b3c77b3bd7715dfa8de02221a3568590a6bfde28 --- /dev/null +++ b/example/64_fpAintB_gemm/CMakeLists.txt @@ -0,0 +1,3 @@ +add_custom_target(example_fpAintB_gemm_wmma) +add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) +add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) diff --git a/example/64_fpAintB_gemm/common.hpp b/example/64_fpAintB_gemm/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4fb4c41d05692743f3130d13a669dd9312eaea42 --- /dev/null +++ b/example/64_fpAintB_gemm/common.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp" + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +struct UnsignedWeightPreprocessor +{ +}; + +template <> +struct UnsignedWeightPreprocessor +{ + using UnsignedWeight = Tensor; + using SignedWeight = Tensor; + static UnsignedWeight convert(SignedWeight const& Input) + { + + UnsignedWeight Output = Input.template CopyAsType(); + + auto f_kn = [&](auto k, auto n) { + const uint8_t adder = 128; + int8_t v_signed_weight; + uint8_t v_unsigned_weight; + + ck::tensor_operation::element_wise::PassThrough{}(v_signed_weight, Input(k, n)); + v_unsigned_weight = ck::type_convert(v_signed_weight) + adder; + Output(k, n) = v_unsigned_weight; + }; + + make_ParallelTensorFunctor(f_kn, Input.mDesc.GetLengths()[0], Input.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return Output; + } + + UnsignedWeight operator()(SignedWeight const& Input) { return convert(Input); } +}; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA 
= std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl; + return false; + } + + return true; +} diff --git a/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp b/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9dc97fecd86f4ee641e37853f4fba9557ffc2715 --- /dev/null +++ b/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp" + +// Implementation follows the paper: +// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. “Who Says Elephants Can’t Run: +// Bringing Large Scale MoE Models into Cloud Scale Production.” arXiv, November 17, 2022. +// https://doi.org/10.48550/arXiv.2211.10017. Assume weight (Matrix B) is add preprocess to +// unsigned. + +// The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType +// The HostRef is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType + +// TODO: Current implementation consume more VGPR than expected. + +using ADataType = ck::half_t; +using QuantDataType = int8_t; +using BDataType = uint8_t; +using ScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_CShuffle + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + ScaleDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/64_fpAintB_gemm/run_gemm_example.inc b/example/64_fpAintB_gemm/run_gemm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..dc2bdc18f01891ad83f3873ec6b86c67a8e419b5 --- /dev/null +++ b/example/64_fpAintB_gemm/run_gemm_example.inc @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
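// Sketch of how the pieces above fit together (inferred from
// UnsignedWeightPreprocessor and the comments in fp16int8_gemm_wmma.cpp, not a
// statement about the kernel internals): the host draws signed int8 weights q,
// stores them on the device shifted into uint8 as b = q + 128, and the device op
// is expected to undo that offset while dequantizing, so that effectively
//
//   C = A * ((b - 128) * scale) == A * (q * scale),
//
// which is what the signed-weight host reference below verifies against.
// Worked example: q = -3, scale = 0.5  ->  b = 125 on device  ->  (125 - 128) * 0.5 = -1.5.
// The scale tensor is a single broadcast row (shape [1, N], K-stride 0), i.e. one
// dequantization scale per output column.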
+ +#pragma once + +bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor quant_b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + // assume scale tensor is [1, n] + Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{})); + + switch(config.init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-1.f, 1.f}(quant_b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-1.f, 1.f}(scale_k_n); + break; + case 2: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(quant_b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); + break; + default: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(quant_b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); + } + + UnsignedWeightPreprocessor preprocessor; + Tensor b_k_n = preprocessor(quant_b_k_n); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "scale_k_n: " << scale_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem scale_k_n_device_buf(sizeof(ScaleDataType) * scale_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + scale_k_n_device_buf.ToDevice(scale_k_n.mData.data()); +#endif + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + 
static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(scale_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + quant_b_k_n, + scale_k_n, + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#endif + } + + return true; +} + +bool run_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..55c884246ff397dc56c31015bf0b168271eb124e --- /dev/null +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -0,0 +1,4 @@ +add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) +add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) \ No newline at end of file diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..580f38a79fc1568c2f2985443292af4366da177b --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
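// This example computes E = A * B + D0 + D1: fp16 A (row-major) and B
// (column-major), fp32 accumulation, two fp32 row-major D tensors added in the
// fused elementwise epilogue (the AddAdd operator defined below), and an fp16 E.
// The host check reproduces the same result by running a plain reference GEMM and
// then applying AddAdd element by element.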
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct AddAdd +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c + d0 + d1; + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + 
e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD, StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cb4f60764e1e0cf1522ebe30095b7f4a8bf7362b --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
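// Same structure as the add-add example, but with fp8 A/B inputs and a
// MultiplyMultiply epilogue: E = (A * B) * d0 * d1 with fp32 accumulation and fp16
// output. With the default StrideD = 0 and the Row/Col layouts chosen for D0/D1,
// d0 effectively broadcasts one value per output column and d1 one value per
// output row -- the usual per-column / per-row dequantization scales for fp8 GEMM.
// KBatch is the split-K factor: the K dimension is partitioned across that many
// workgroup batches.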
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using B0DataType = FP8; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RRR + ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +///###### RCR + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem 
b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2568754648d6a04a3066921c0f2ea2e3ee967962 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
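+
+// Overview (mirrors the host verification further down): A and B are FP8 tensors that carry
+// per-tile FP32 scales A1 and B1. Every Scale_Block_M x Scale_Block_K tile of A and every
+// Scale_Block_K x Scale_Block_N tile of B shares a single scale, so the effective operands are
+//
+//   a(m, k) = float(a0(m, k)) * a1(m / Scale_Block_M, k / Scale_Block_K)
+//   b(k, n) = float(b0(k, n)) * b1(k / Scale_Block_K, n / Scale_Block_N)
+//
+// and the BF16 output is e(m, n) = bf16(sum over k of a(m, k) * b(k, n)).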
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 128; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 + // clang-format off + , S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + exit(0); + } + + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, 
{stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AM, + A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + +#if 1 + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } +#endif + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + 
cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + Tensor a_m_k({M, K}); + Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + +#if 1 + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } +#endif + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb1642bba544fba26f30c5c048234aa96de453dd --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
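+
+// Overview (mirrors the MultiplyMultiply operator and the host verification further down): A and
+// B are int8 and the GEMM accumulates in int32; D0 (row-major) and D1 (column-major) hold FP32
+// scale factors that are applied element-wise when the FP16 output is written:
+//
+//   e(m, n) = fp16(float(c(m, n)) * d0(m, n) * d1(m, n))
+//
+// where c is the int32 accumulator of a0 * b0.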
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int; +using F16 = ck::half_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = I8; +using B0DataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| 
Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RRR + ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; +///###### RCR + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + exit(0); + } + do_verification = false; + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: 
break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 
0 : 1;
+    }
+
+    return 0;
+}
diff --git a/example/66_complex_contraction_bilinear/CMakeLists.txt b/example/66_complex_contraction_bilinear/CMakeLists.txt
new file mode 100755
index 0000000000000000000000000000000000000000..c417caf8e7b556d54fe487064ab119b56ceec6f4
--- /dev/null
+++ b/example/66_complex_contraction_bilinear/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_example_executable(example_complex_contraction_bilinear_xdl_fp32 complex_contraction_bilinear_xdl_fp32.cpp)
+add_example_executable(example_complex_contraction_bilinear_xdl_fp64 complex_contraction_bilinear_xdl_fp64.cpp)
+
diff --git a/example/66_complex_contraction_bilinear/README.md b/example/66_complex_contraction_bilinear/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..04d92da0d2d3e3f3874b3865ad563d799f172890
--- /dev/null
+++ b/example/66_complex_contraction_bilinear/README.md
@@ -0,0 +1,11 @@
+# Instructions for ```example_complex_contraction_bilinear_xdl_fp32```
+
+## Run
+```bash
+#arg1: verification (0=no, 1=yes)
+#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+#arg3: time kernel (0=no, 1=yes)
+./bin/example_complex_contraction_bilinear_xdl_fp32 1 1 1
+```
+
+
diff --git a/example/66_complex_contraction_bilinear/common_instances.hpp b/example/66_complex_contraction_bilinear/common_instances.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..480ca5a0af67bcf8db871fcd79e795497901c5ab
--- /dev/null
+++ b/example/66_complex_contraction_bilinear/common_instances.hpp
@@ -0,0 +1,196 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
+
+using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
+using F32 = float;
+using F64 = double;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Generic instances for fp32, fp16 and bf16 data types.
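+//
+// The KK / KN / MK / MN suffixes below appear to encode which dimension is innermost for the A and
+// B tensors respectively (K-contiguous versus M- or N-contiguous); the four variants differ mainly
+// in their A/B block-transfer parameters (access order, source vector dimension, vector widths and
+// the AK1/BK1 values). The *_Generic instances use a 256x128 block tile with 32x32 XDL
+// instructions, while the *_FP64 instances further down use a 128x128 tile with 16x16 XDL
+// instructions and narrower vector widths.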
+template +// clang-format off +using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, 
AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + 
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +// Fp64 instances. +template +// clang-format off +using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | 
| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp new file mode 100755 index 0000000000000000000000000000000000000000..619279c47b263885c5ec223c6ad9952ccb1b298a --- /dev/null +++ b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_complex_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); } diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp new file mode 100755 index 0000000000000000000000000000000000000000..f3528d09019e25feae9dcab8e47d92951bff0695 --- /dev/null +++ b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
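+
+// Overview: the FP32 and FP64 examples share run_complex_contraction_bilinear_example.inc, which
+// stores real and imaginary parts in separate tensors and decomposes the complex operation
+// E = A * B + D into four real contraction+bilinear launches, using the Bilinear op's alpha/beta
+// to carry the signs and the accumulation:
+//
+//   E_re = (A_re * B_re + D_re),  then  E_re = -(A_im * B_im) + E_re
+//   E_im = (A_re * B_im + D_im),  then  E_im =  (A_im * B_re) + E_im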
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F64; +using CShuffleDataType = F64; +using DDataType = F64; +using DsDataType = ck::Tuple; +using EDataType = F64; +using ComputeDataType = F64; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_complex_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); } diff --git a/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..82ac0a15e1d9042221e69d64c87d80ab079508b1 --- /dev/null +++ b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc @@ -0,0 +1,481 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +int run_complex_contraction_bilinear_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), 
std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + // For Real Part of Complex Tensor + Tensor a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides); + + Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides); + + // For Imaginary Part of Complex Tensor + Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides); + + Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides); + + // Intermediate E tensor Definition + Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl; + std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl; + std::cout << "d_ms_ns_re: " << d_ms_ns_re.mDesc << std::endl; + std::cout << "e_ms_ns_re: " << e_ms_ns_host_result_re.mDesc << std::endl; + + std::cout << "a_ms_ks_img: " << a_ms_ks_img.mDesc << std::endl; + std::cout << "b_ns_ks_img: " << b_ns_ks_img.mDesc << std::endl; + std::cout << "d_ms_ns_img: " << d_ms_ns_img.mDesc << std::endl; + std::cout << "e_ms_ns_img: " << e_ms_ns_host_result_img.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + + a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + + default: + a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + + a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + + break; + } + + DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf_re(sizeof(DDataType) * 
d_ms_ns_re.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf_re(sizeof(EDataType) * + e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); + + DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf_img(sizeof(EDataType) * + e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); + + // Intermediate Value For E Real and Img + DeviceMem e_device_buf_re1(sizeof(EDataType) * + e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf_img1(sizeof(EDataType) * + e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); + + a_device_buf_re.ToDevice(a_ms_ks_re.mData.data()); + b_device_buf_re.ToDevice(b_ns_ks_re.mData.data()); + d_device_buf_re.ToDevice(d_ms_ns_re.mData.data()); + + a_device_buf_img.ToDevice(a_ms_ks_img.mData.data()); + b_device_buf_img.ToDevice(b_ns_ks_img.mData.data()); + d_device_buf_img.ToDevice(d_ms_ns_img.mData.data()); + + // set zero + e_device_buf_re.SetZero(); + e_device_buf_img.SetZero(); + + // set zero for intermediate values + e_device_buf_re1.SetZero(); + e_device_buf_img1.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + // For real Intermediate Value re_1 + + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument_re1 = + op.MakeArgument(a_device_buf_re.GetDeviceBuffer(), + b_device_buf_re.GetDeviceBuffer(), + std::array{d_device_buf_re.GetDeviceBuffer()}, + e_device_buf_re1.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_re1)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel}); + + alpha = -1.f; + beta = 1.f; + + a_element_op = AElementOp{}; + b_element_op = BElementOp{}; + cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + // For real Intermediate Value re_2 + // auto op = DeviceOpInstance{}; + // auto invoker = op.MakeInvoker(); + auto argument_re2 = + op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), + b_device_buf_img.GetDeviceBuffer(), + std::array{e_device_buf_re1.GetDeviceBuffer()}, + e_device_buf_re.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_re2)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel}); + + alpha = 1.f; + beta = 1.f; + + a_element_op = AElementOp{}; + b_element_op = BElementOp{}; + cde_element_op = CDEElementOp{alpha, beta}; + + auto argument_img1 = + op.MakeArgument(a_device_buf_re.GetDeviceBuffer(), + b_device_buf_img.GetDeviceBuffer(), + std::array{d_device_buf_img.GetDeviceBuffer()}, + e_device_buf_img1.GetDeviceBuffer(), + a_ms_ks_lengths, + 
a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_img1)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_img1 = invoker.Run(argument_img1, StreamConfig{nullptr, time_kernel}); + + alpha = 1.f; + beta = 1.f; + + auto argument_img2 = + op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), + b_device_buf_re.GetDeviceBuffer(), + std::array{e_device_buf_img1.GetDeviceBuffer()}, + e_device_buf_img.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_img2)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K * 2; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2; + + float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf_re.FromDevice(e_ms_ns_device_result_re.mData.data()); + e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data()); + + auto isRealOk = 0; + auto isImgOk = 0; + + if(do_verification) + { + // Real Part Verification + Tensor c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument_re = ref_op.MakeArgument( + a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_re); + + alpha = 1.f; + beta = 1.f; + + cde_element_op = CDEElementOp{alpha, beta}; + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1), + c_ms_ns_host_result_re(m0, m1, n0, n1), + d_ms_ns_re(m0, m1, n0, n1)); + } + } + } + } + + alpha = 1.f; + beta = -1.f; + + cde_element_op = CDEElementOp{alpha, beta}; + + auto ref_argument_re1 = ref_op.MakeArgument( + a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_re1); + + for(size_t m0 = 
0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1), + e_ms_ns_host_result_re(m0, m1, n0, n1), + c_ms_ns_host_result_re1(m0, m1, n0, n1)); + } + } + } + } + + isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1; + + // Img Part Verification + Tensor c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + + auto ref_argument_img = ref_op.MakeArgument( + a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_img); + + alpha = 1.f; + beta = 1.f; + + cde_element_op = CDEElementOp{alpha, beta}; + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1), + c_ms_ns_host_result_img(m0, m1, n0, n1), + d_ms_ns_img(m0, m1, n0, n1)); + } + } + } + } + + auto ref_argument_img1 = ref_op.MakeArgument( + a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_img1); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1), + e_ms_ns_host_result_img(m0, m1, n0, n1), + c_ms_ns_host_result_img1(m0, m1, n0, n1)); + } + } + } + } + + isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 
0 : 1; + + return (isRealOk && isImgOk); + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 1fdd2f6d14826776b734b76815e192b4654e5b09..ea739c707153dd43c5a0c130e8fa916973e83b90 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -5,28 +5,190 @@ include_directories(BEFORE add_custom_target(examples) +function(add_example_dependencies EXAMPLE_NAME FILE_NAME) + if(FILE_NAME) + add_dependencies(EXAMPLE_NAME FILE_NAME) + endif() +endfunction(add_example_dependencies EXAMPLE_NAME) + function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") - add_executable(${EXAMPLE_NAME} ${FILE_NAME}) - target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) - add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) - add_dependencies(examples ${EXAMPLE_NAME}) - add_dependencies(check ${EXAMPLE_NAME}) - rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example source file ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + endif() + + set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) + + #Do not build any DL examples if DL_KERNELS not set + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any FP8 examples if CK_ENABLE_FP8 not set + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") + message("removing fp8 example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any BF8 examples if CK_ENABLE_BF8 not set + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8") + message("removing bf8 example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(FILE_NAME) + if(FILE_NAME MATCHES "_xdl") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 
gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + elseif(FILE_NAME MATCHES "_wmma") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + add_dependencies(examples ${EXAMPLE_NAME}) + add_dependencies(check ${EXAMPLE_NAME}) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 0) + endif() + #message("add_example returns ${result}") + set(result ${result} PARENT_SCOPE) endfunction(add_example_executable EXAMPLE_NAME) +function(add_example_dependencies EXAMPLE_NAME FILE_NAME) + if(result EQUAL 0) + add_dependencies(${EXAMPLE_NAME} ${FILE_NAME}) + endif() +endfunction(add_example_dependencies EXAMPLE_NAME) + function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") - add_executable(${EXAMPLE_NAME} ${FILE_NAME}) - target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) - add_dependencies(examples ${EXAMPLE_NAME}) - rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + endif() + + set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) + + #Do not build any DL examples if DL_KERNELS not set + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(FILE_NAME) + if(FILE_NAME MATCHES "_xdl") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic 
gfx12-generic) + elseif(FILE_NAME MATCHES "_wmma") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + add_dependencies(examples ${EXAMPLE_NAME}) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 0) + endif() + #message("add_example returns ${result}") + set(result ${result} PARENT_SCOPE) endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) - IF(IS_DIRECTORY "${subdir}") + if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt") add_subdirectory(${subdir}) ENDIF() ENDFOREACH() diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1ba76a523eed58b60aa67b2a58b9af0825647f06 --- /dev/null +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -0,0 +1,116 @@ +# validate user-specified fmha_fwd API list +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv") +set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") +if(FMHA_FWD_ENABLE_APIS STREQUAL "all") + set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) +endif() + +foreach(api ${FMHA_FWD_ENABLE_APIS}) + if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) + message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.") + endif() +endforeach() + +# "fwd" is a must-have api for the fmha_fwd example, add it if not specified +if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND FMHA_FWD_ENABLE_APIS "fwd") +endif() + +string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") +# generate a list of kernels, but not actually emit files at config sta +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") +endif() + +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3 + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.") +endif() + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory +# as current cmake list, otherwise will not figure out the dependency properly +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${FMHA_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} +) + +add_custom_command( + OUTPUT ${FMHA_BWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3 +) + +set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") +# not using 
add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_FMHA_FWD}") +add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) +target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) + +set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_FMHA_BWD}") +add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) +target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) + +# NOTE: this is dangerous since will change the whole kernel to flush denormals +# WIP with compiler team for an exp2 intrinsic..., then remove this +if(NOT DEFINED FMHA_FWD_FAST_EXP2) + set(FMHA_FWD_FAST_EXP2 true) +endif() + +set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) +set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +# ... because they are auto-generated +if(FMHA_FWD_FAST_EXP2) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) +endif() +list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero) + +# conditionally enable call to the fwd_splitkv API in fmha_fwd example +if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) +endif() + +# conditionally enable call to the fwd_appendkv API in fmha_fwd example +if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) +endif() + +# Allow comparing floating points directly in order to check sentinel values +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) +list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) + +target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) +target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c7ab296c3bbedd5c1358e45cf28fb29b8786c6f5 --- /dev/null +++ b/example/ck_tile/01_fmha/README.md @@ -0,0 +1,132 @@ +# fused multi-head attention + +This folder contains example for fmha(fused multi-head attention) using ck_tile tile-programming implementation. 
It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +make tile_example_fmha_fwd -j +``` +This will result in an executable `build/bin/tile_example_fmha_fwd` + +## kernel +The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. + +There are 3 template parameters for this kernel template. +* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose. +* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). +* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. + +## codegen +To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. + +## executable +`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all the arguments. Below is an example of the output (may subject to change) +``` +args: + -v weather do CPU validation or not (default:1) + -mode kernel mode. 0:batch, 1:group (default:0) + -b batch size (default:2) + -h num of head, for q (default:8) + -h_k num of head, for k/v, -1 means equal to h (default:-1) + if not equal to h, then this is GQA/MQA case + -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) + total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary + also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) + -s_k seqlen_k (including new key/value), -1 means equal to s (default:-1) + -d head dim for q, k (default:128) + -d_v head dim for v, -1 means equal to d (default:-1) + -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) + note when squant=1, this value will be modified by range_q/k + -range_q per-tensor quantization range of q. used if squant=1. (default:16) + -range_k per-tensor quantization range of k. used if squant=1. (default:16) + -range_v per-tensor quantization range of v. used if squant=1. (default:16) + -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. 
(default:1) + -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) + -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto) + 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. + calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o + -iperm permute input (default:1) + if true, will be b*h*s*d, else b*s*h*d + -operm permute output (default:1) + -bias n or 0, no bias (default:n) + e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s + a(libi) or 2, alibi with 1*h. a:1, b*h + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) + 't', top-left causal mask, 'b', bottom-r causal mask + 't:l,r', top-left sliding window attn(swa) with FA style left right size + 'b:l,r', bottom-r sliding window attn(swa) with FA style left right size + 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa + 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa + 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) + -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) + -lse 0 not store lse, 1 store lse (default:0) + -kname if set to 1 will print kernel name (default:0) + -init init method. ui, uniform random int, ni, normalized random int (default:uf) + uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization + -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) + -drop_seed seed for random number generator (default:1) +-drop_offset offset for random number generator (default:0) + -drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0) + -warmup number of iterations before benchmark the kernel (default:5) + -repeat number of iterations to benchmark the kernel (default:20) +``` +Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. +Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with + batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case + +## support features +Currently we are still in rapid development stage, so more features/optimizations will be coming soon. + +### hdim +Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. hdim should be multiple of 8, while seqlen_s can be arbitrary. For hdim be arbitrary number, it can be support through padding kernel of `qr` pipeline (we didn't generate this in generate.py by default) + +### group/batch mode +Currently we support both `batch mode` and `group mode` (or `varlen`, in FA's term), by setting `-mode` = `0` or `1`. In `group mode` different kind of attention mask is also supported(see below) + +### MQA/GQA +By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers. 
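As a quick illustration of the grouping (a sketch assuming the usual GQA convention that consecutive query heads share one k/v head; the exact indexing inside the kernel may differ):
```
# which k/v head serves each query head for h=8, h_k=2 (GQA with group size 4)
h, h_k = 8, 2
ratio = h // h_k                        # valid because h % h_k == 0
kv_head_of_q = [i // ratio for i in range(h)]
# -> [0, 0, 0, 0, 1, 1, 1, 1]
```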
+ +### input/output permute, and `b*s*3*h*d` +If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trivial to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`. + +### attention bias +Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number. + +### alibi +alibi is supported + +### lse +For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1` + +### vlayout +We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts. + +### attention mask +we support `causal mask` and `sliding window attention(swa)` mask in both batch and group mode, either from top-left or bottom-right. +Underneath, we unify the mask expression into `generic attention mask coordinate`, providing an uniformed approach for each batch to locate the corresponding pixel need to be masked out. +![](misc/gamc.png) + +Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline. + +| mask case| cmdline | FA style | xformer style | +|----------|:-------------:|:-------------:|:-------------:| +| no mask | `-mask=0`(default) | | | +| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` | +| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` | +| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` | +| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` | + +Note FA use bottom-right by default to express swa case, here we require you explicitly specify top-left/bottom-right. + +### dropout +TBD + +## FP8 experimental support +As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx940/941/942 machine and ROCm 6.0+. + +Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later. 
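## scripted runs
For convenience, below is a minimal Python driver that sweeps a few of the flags documented above. The binary path follows the build step at the top of this README; the flag values are only examples of the documented options, not a recommended benchmark set.
```
import subprocess

exe = "./build/bin/tile_example_fmha_fwd"        # produced by the build step above
for s in (1024, 4096, 16384):
    for mask in ("0", "t", "b:10,11"):           # no mask / top-left causal / bottom-right swa
        subprocess.run([exe, "-b=1", "-h=8", f"-s={s}", "-d=128", f"-mask={mask}"], check=True)

# experimental fp8 path (gfx940/941/942, ROCm 6.0+), as described above
subprocess.run([exe, "-prec=fp8", "-vlayout=c", "-squant=1", "-d=128"], check=True)
```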
diff --git a/example/ck_tile/01_fmha/bias.hpp b/example/ck_tile/01_fmha/bias.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f9dc656f6374594acb269949c839ad46e12759e8 --- /dev/null +++ b/example/ck_tile/01_fmha/bias.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || + str.compare(0, 11, "elementwise") == 0) + { + info.type = bias_enum::elementwise_bias; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || + str.compare(0, 5, "alibi") == 0) + { + info.type = bias_enum::alibi; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/01_fmha/codegen/__init__.py b/example/ck_tile/01_fmha/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/example/ck_tile/01_fmha/codegen/cmake_config.py b/example/ck_tile/01_fmha/codegen/cmake_config.py new file mode 100644 index 0000000000000000000000000000000000000000..03ebfd67021360c9f149d12232c790fa9ca06fb1 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/cmake_config.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +GEN_DIR = "" # in Cmake, have to generate files in same folder \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py new file mode 100644 index 0000000000000000000000000000000000000000..66691356ab7d882ff6d073c0e38c72cc416d6320 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
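# Illustrative cross-reference (reader aid only, not used by the generator):
# the user-facing "-bias" strings parsed by bias.hpp above relate to the short
# keys of BIAS_MAP / BIAS_CHECK_MAP defined below roughly as follows.
_BIAS_FLAG_EXAMPLES = {  # illustration only
    "n"   : "no",     # "-bias=n" (or "0"): no bias
    "e:1" : "bias",   # elementwise bias; "e" -> 1*1*s*s, "e:1" -> 1*h*s*s, "e:2" -> b*h*s*s
    "a:1" : "alibi",  # alibi slopes; "a" -> 1*h, "a:1" -> b*h
}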
+# generate kernel instances to speed up compilation + +DTYPE_MAP = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp8" : "ck_tile::fp8_t" +} + +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +_MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + +_MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +def get_mask_map(mask : str): + if mask == "generic": + return _MASK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_MAP + else: + assert False + return None + +_MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +_MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + +def get_mask_check_map(mask : str): + if mask == "generic": + return _MASK_CHECK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + +BIAS_MAP = { + "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" +} + +# TODO: this is ugly +BIAS_CHECK_MAP = { + "no" : "bias_enum::no_bias", + "bias" : "bias_enum::elementwise_bias", + "alibi" : "bias_enum::alibi" +} + +DROPOUT_MAP = { + "no" : "ck_tile::BlockDropoutBwd", + "dropout_wg32" : "ck_tile::BlockDropoutBwd", + "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", + "dropout_wg16" : "ck_tile::BlockDropoutBwd", + "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" +} + +DROPOUT_CHECK_MAP = { + "no" : "t.has_dropout == false", + "dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true", + "dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true", +} + +ROPE_MAP = { + "no" : "ck_tile::RotaryEmbeddingEnum::NONE", + "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", + "half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED" +} + +ROPE_CHECK_MAP = { + "no" : "rope_enum::none", + "inter" : "rope_enum::interleaved", + "half" : "rope_enum::half_rotated" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", +} + +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false" +} diff --git a/example/ck_tile/01_fmha/codegen/ops/__init__.py b/example/ck_tile/01_fmha/codegen/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..096394c0c95e8833051a6c5de06f49528eeb8ebf --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -0,0 +1,804 @@ +# 
SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +BWD_DQDKDV_PIPELINE_MAP = { + "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP", + "kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR", +} + +BWD_DQDKDV_PIPELINE_ENUM_MAP = { + "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP", + "kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR", +} + +FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_bwd.hpp" +""" + +FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile:: + sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; +using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; +using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; +using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; +using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>; +using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_dbias}, + false, + false, + false, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; +using fmha_dropout_{F_idx} = {F_dropout}; + +using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_{F_idx}, + {F_mode}, + {F_deterministic}, + fmha_mask_{F_idx}, + fmha_dropout_{F_idx}, + fmha_bwd_trait_{F_idx}>; + +using fmha_bwd_pipeline_{F_idx} = {F_pipeline}; + +using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, + {F_skpad}, + {F_dpad}>>; + +using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, + {F_skpad}, + {F_dvpad}>>; + +using fmha_bwd_dq_dk_dv_kernel_{F_idx} = + ck_tile::FmhaBwdDQDKDVKernel; + +using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, + {F_dtype}, + {F_mode}, + {F_pipeline_enum}, + fmha_mask_{F_idx}, + fmha_dropout_{F_idx}, + 
{F_bias}, + {F_dbias}, + {F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_deterministic}>; + +#include + +template <> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template <> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template <> +std::string fmha_bwd_dq_dk_dv_get_name_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" +FMHA_BWD_API=""" +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} + ); +}} + +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>; + r = fmha_bwd_(s, a); + return r; + }} +""" + +@dataclass +class FmhaBwdDQDKDVApiTrait: + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along k seqlen + bhdq : int # q head_dim + bhdv : int # v head_dim + mask : str + bias : str + dbias : str + dropout : str + spad : str + skpad : str + dpad : str + dvpad : str + deterministic : str + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return 
f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' + else: # self.skpad == 'f' and skpad1 == 'f' + return f'a.seqlen_q % 64 == 0' + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: + return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + hdim_int = int(hdim) + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_deterministic=BOOL_MAP[trait.deterministic]) + + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +# GEMM0: Q@K=S^T +# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) +# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) +# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) +# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) +# Is it necessary to distinguish between K0~K4? 
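#
# Illustrative reference (an aside, not used by the generator): the standard
# row-major attention backward math that GEMM0~GEMM4 above tile. The kernels
# operate on transposed tiles (S^T, dP^T, ...) and fold in the softmax scale;
# this NumPy sketch and the helper name are for illustration only.
def _attention_bwd_reference(q, k, v, do, scale):
    import numpy as np
    s = scale * (q @ k.T)                      # GEMM0: S = scale * Q @ K^T
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)         # row-wise softmax
    o = p @ v                                  # forward output O = P @ V
    dv = p.T @ do                              # GEMM1: dV = P^T @ dO
    dp = do @ v.T                              # GEMM2: dP = dO @ V^T
    d = (do * o).sum(axis=-1, keepdims=True)   # D = rowsum(dO * O), cf. the dot_do_o kernel below
    ds = p * (dp - d)                          # softmax backward
    dk = scale * (ds.T @ q)                    # GEMM3: dK = scale * dS^T @ Q
    dq = scale * (ds @ k)                      # GEMM4: dQ = scale * dS @ K
    return dq, dk, dv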
+@dataclass +class FmhaBwdDQDKDVTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along gemm0 unroll(F_bhdq) + F_bk1 : int # tile size along gemm1 unroll(F_bm0) + F_bk2 : int # tile size along gemm2 unroll(F_bhdv) + F_bk3 : int # tile size along gemm3 unroll(F_bm0) + F_bk4 : int # tile size along gemm4 unroll(F_bn0) + F_bhdq : int # q head_dim + F_bhdv : int # v head_dim + F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 + F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 + F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2 + F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 + F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3 + F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3 + F_rm2 : int # number of warps along q seqlen (block warps) in gemm4 + F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4 + F_rk2 : int # number of warps along k seqlen (not used) in gemm4 + F_wm0 : int # warp size along m in gemm0/gemm2/gemm4 + F_wn0 : int # warp size along n in gemm0/gemm2/gemm4 + F_wk0 : int # warp size along k in gemm0/gemm2/gemm4 + F_wm1 : int # warp size along m in gemm1/gemm3 + F_wn1 : int # warp size along n in gemm1/gemm3 + F_wk1 : int # warp size along k in gemm1/gemm3 + F_occupancy : int # occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}" + +@dataclass +class FmhaBwdDQDKDVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaBwdDQDKDVTileSize + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # + F_dbias : str # + F_dropout : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_deterministic : str # + F_pipeline : str # + mask_impl : str # + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bk1 = self.F_tile.F_bk1, + F_bk2 = self.F_tile.F_bk2, + F_bk3 = self.F_tile.F_bk3, + F_bk4 = self.F_tile.F_bk4, + F_bhdq = self.F_tile.F_bhdq, + F_bhdv = self.F_tile.F_bhdv, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_rm2 = self.F_tile.F_rm2, + F_rn2 = self.F_tile.F_rn2, + F_rk2 = self.F_tile.F_rk2, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_bias = BIAS_MAP[self.F_bias], + F_dbias = BOOL_MAP[self.F_dbias], + F_dropout = DROPOUT_MAP[self.F_dropout], + F_occupancy = self.F_tile.F_occupancy, + 
F_mask = get_mask_map(self.mask_impl)[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_deterministic = BOOL_MAP[self.F_deterministic], + F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], + F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_dbias == 't' : n += '_dbias' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_dropout != 'no' : n += f'_{self.F_dropout}' + if self.F_deterministic == 't' : n += '_deterministic' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaBwdDQDKDVApiTrait: + return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bhdq=self.F_tile.F_bhdq, + bhdv=self.F_tile.F_bhdv, + mask=self.F_mask, + bias=self.F_bias, + dbias=self.F_dbias, + dropout=self.F_dropout, + spad=self.F_spad, + skpad=self.F_skpad, + dpad=self.F_dpad, + dvpad=self.F_dvpad, + deterministic=self.F_deterministic + ) + +# TODO: design a more practical way to do it +# this is current supported tile size & pipeline. +def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"], + '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"], + '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"], + '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"] + } + else: + return None + +def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: + # TODO: we don't support tuning yet, so pick up one value for pad + # support this in future + gen = list() + api_pool = FmhaBwdApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + tile = d[hdim_str][0] + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + if ("wg32" in dropout): + continue + if (dpad == "t" or dvpad == "t"): + ppl = d[hdim_str][2] + k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, 
+ F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, + F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, + F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue + if receipt == 3: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue + api_pool.register_dq_dk_dv_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_dot_do_o_trait_{F_idx} = + ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>; + +using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 64, + {F_hdim}, + {F_mode}, + fmha_bwd_dot_do_o_trait_{F_idx}>; + +using fmha_bwd_dot_do_o_{F_idx} = + typename ck_tile::BlockFmhaBwdOGradDotO; + +using fmha_bwd_dot_do_o_kernel_{F_idx} = + ck_tile::FmhaBwdOGradDotOKernel; + +using dot_do_o_trait_{F_idx} = + fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; + +#include + +template <> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template <> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template <> +std::string fmha_bwd_dot_do_o_get_name_() +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass +class FmhaBwdOGradDotOKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_spad : str # true/false + F_dvpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_spad = BOOL_MAP[self.F_spad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + return n + + 
@property + def filename(self) -> str: + return self.name + ".cpp" + +def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): + hdim = int(hdim_str) + if (mode == "group" and spad == "f"): + continue + k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, + F_spad=spad, F_dvpad=dvpad, F_mode=mode, + F_occupancy=get_occupancy(dtype, hdim)) + gen.append(k) + + return gen + +FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_convert_dq_trait_{F_idx} = + ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>; + +using fmha_bwd_convert_dq_pipeline_problem_{F_idx} = + ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + {F_bm0}, + {F_bn0}, + {F_hdim}, + {F_mode}, + {F_deterministic}, + fmha_bwd_convert_dq_trait_{F_idx}>; + +using fmha_bwd_convert_dq_{F_idx} = + typename ck_tile::BlockFmhaBwdConvertQGrad; + +using fmha_bwd_convert_dq_kernel_{F_idx} = + ck_tile::FmhaBwdConvertQGradKernel; + +using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, + {F_dtype}, + {F_mode}, + {F_spad}, + {F_dpad}, + {F_deterministic}>; + +#include + +template <> +float fmha_bwd_convert_dq_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template <> +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, + fmha_bwd_args a) +{{ + using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template <> +std::string fmha_bwd_convert_dq_get_name_() +{{ + using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass +class FmhaBwdConvertQGradKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_spad : str # true/false + F_dpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int # + F_deterministic : str # + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_bm0, + F_bn0 = self.F_bn0, + F_spad = BOOL_MAP[self.F_spad], + F_dpad = BOOL_MAP[self.F_dpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy, + F_deterministic = BOOL_MAP[self.F_deterministic]) + + 
@property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dpad == 't' : n += 'd' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + if self.F_deterministic == 't' : n += f'_deterministic' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + hdim = int(hdim_str) + tile = d[hdim_str][0] + if (mode == "group" and spad == "f"): + continue + k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0, + F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) + gen.append(k) + + return gen + +def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + write_single_bwd_dot_do_o_kernel(kernel, output_dir) + kernels = get_bwd_convert_dq_blobs() + for kernel in kernels: + write_single_bwd_convert_dq_kernel(kernel, output_dir) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) + write_bwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + kernels = get_bwd_convert_dq_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ee1d22e74155a91df630494da6c1035f7fdf38 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -0,0 +1,524 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +TILE_PARTITIONER_MAP = { + "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", + "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + fmha_warp_tile_{F_idx}, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + fmha_warp_tile_{F_idx}, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdKernel<{F_tile_partitioner}, + fmha_pipeline_{F_idx}, + fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= 
{F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_fwd_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of 
warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + def get_tp(self) -> str: + if self.F_mode == 'group': + return 'hbs' + else: + return 'shb' + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, 
+ dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), + ## '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + if hdim == 256: + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + if bias == "bias": + # TODO: rocm 6.2 compiler problem if using qr_async for bias case + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + if receipt == 1 and bias != "bias": + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need 
lse/dropout kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd1d01c91a57e824a8ef7370af56190ccede4f5 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
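+# NOTE (editorial overview, illustrative only): this module mirrors codegen.ops.fmha_fwd
+# but generates the append-KV (kv-cache append / rotary-embedding) kernels and their
+# dispatch file fmha_fwd_appendkv_api.cpp; shared pieces such as FMHA_FWD_KERNEL_HEADER
+# and the per-dtype/per-hdim dispatch wrappers are imported from codegen.ops.fmha_fwd.
+# A hypothetical driver call (paths/values are placeholders, not part of this patch):
+#   from codegen.ops import fmha_fwd_appendkv
+#   fmha_fwd_appendkv.write_blobs(Path("build/autogen"), kernel_filter=None,
+#                                 receipt=0, mask_impl=...)  # mask_impl: a key understood by get_mask_map()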
+# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + +from codegen.ops.fmha_fwd import ( + FmhaFwdApiTrait, + DTYPE_BITS, + FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_DTYPE, + FMHA_FWD_API_PER_HDIM_CASE, +) + + +FMHA_FWD_APPENDKV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_occupancy}>; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + {F_bs}, + {F_bsk}, + {F_bd}, + {F_bdv}, + {F_vlayout}, + {F_rope}, + {F_pagedkv}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< + fmha_pipeline_problem_{F_idx}>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdAppendKVKernel, + fmha_pipeline_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, + {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; + +#include + +template<> +float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp" +FMHA_FWD_APPENDKV_API=""" +float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) && + ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{ + using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; + return fmha_fwd_appendkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdAppendKVApiTrait: + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + bs : int # tile size along q seqlen + bsk : int # tile size along k seqlen + bd : int # tile size along qk gemm unroll + bdv : int # tile size along kv gemm unroll + vlayout : str + spad : str + skpad : str + dpad : str + dvpad : str + rope : str # key from ROPE_MAP + pagedkv : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\ + f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}' + + @property + def scheck(self) -> str: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/' + else : return f'a.seqlen_q % {self.bs} == 0' + + @property + def skcheck(self) -> str: + # we do not check all the values in a.seqlen_k_ptr + return 'true' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return 
f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {self.bd} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {self.bdv} == 0' + +@dataclass +class FmhaFwdAppendKVPipeline: + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_rope : str # key from ROPE_MAP + F_pagedkv : str # t/f + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_rope != 'no': n += f'_{self.F_rope}' + if self.F_pagedkv == 't': n += '_pagedkv' + return n + +class FmhaFwdAppendKVApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], + F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdAppendKVTileSize: + F_bs : int # tile size along q seqlen + F_bsk : int # tile size along k seqlen + F_bd : int # tile size along qk gemm unroll + F_bdv : int # tile size along kv gemm unroll + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdAppendKVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaFwdAppendKVTileSize + F_pipeline : FmhaFwdAppendKVPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + 
FMHA_FWD_APPENDKV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bs = self.F_tile.F_bs, + F_bsk = self.F_tile.F_bsk, + F_bd = self.F_tile.F_bd, + F_bdv = self.F_tile.F_bdv, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_rope = ROPE_MAP[self.F_pipeline.F_rope], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], + F_occupancy = self.F_tile.F_occupancy) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdAppendKVApiTrait: + return FmhaFwdAppendKVApiTrait( + hdim=str(self.F_hdim), + dtype=self.F_dtype, + bs=self.F_tile.F_bs, + bsk=self.F_tile.F_bsk, + bd=self.F_tile.F_bd, + bdv=self.F_tile.F_bdv, + vlayout=self.F_pipeline.F_vlayout, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + rope=self.F_pipeline.F_rope, + pagedkv=self.F_pipeline.F_pagedkv) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1), + '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1), + '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1), + '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1) + } + else: + return None + +def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? 
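+        # Descriptive note (added for clarity): for fp16/bf16 the loop below enumerates
+        # V layout (row/col) x paged-kv (t/f) x rope variant (no/inter/half). skpad is
+        # always 't' (the dispatch skcheck above does not inspect per-batch seqlen_k
+        # values) and hdim_q padding stays 't' for the inter/half rope pipelines (see the
+        # NOTICE below). fp8/bf8 gets a single col-major, fully padded pipeline with
+        # neither rope nor paged-kv.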
+ squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while + # applying rotary embedding, so I just use 't' in inter/half pipelines + for vlayout in ['row', 'col']: + for pagedkv in ["t", "f"]: + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv)) + + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv)) + + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv)) + pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv)) + elif dtype in ['fp8', 'bf8']: + # rope/paged-kv is not supported + pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdAppendKVApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str in d.keys(): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + k = FmhaFwdAppendKVKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_fwd_appendkv_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py new file mode 100644 index 0000000000000000000000000000000000000000..b084e9d0fcde38d8786f19f769e731e5c0c16786 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -0,0 +1,758 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
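+# NOTE (editorial overview, illustrative only): split-KV attention is generated as two
+# cooperating kernels. The split kernel (FMHA_FWD_SPLITKV_KERNEL_BODY) computes partial
+# outputs and per-split LSE, while the combine kernel (FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY)
+# reduces across splits, picking kLogMaxSplits from a.num_splits (<=8 -> 3, <=16 -> 4,
+# ..., <=128 -> 7). The generated fmha_fwd_splitkv_api.cpp launches the pair back to back
+# via fmha_fwd_splitkv_oneshot_ / fmha_fwd_splitkv_combine_oneshot_, and the combine tile
+# is dispatched as <bm0/2, bn1/2> of the split kernel's tile
+# (see FMHA_FWD_SPLITKV_API_INNER_DISPATCH).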
+# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple, Union + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + +from codegen.ops.fmha_fwd import ( + FmhaFwdTileSize, + FmhaFwdApiTrait, + FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_DTYPE, + FMHA_FWD_API_PER_HDIM_CASE, +) + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_FWD_SPLITKV_PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", +} + +FMHA_FWD_SPLITKV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_mask_{F_idx} = {F_mask}; + +namespace {{ +template +struct kernel_runner {{ +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; +using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape = ck_tile::TileFmhaShape, + fmha_warp_tile, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + fmha_warp_tile, + {F_vlayout}>; + +using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_squant}, + {F_pagedkv}, + kHasUnevenSplits, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::OaccDataType, + fmha_shape, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait>; + +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; + +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVKernel, + fmha_pipeline, + fmha_epilogue>; + +static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_dvpad}>; + +#include + +template<> +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + if constexpr({F_mode} == false) {{ // batch mode + // we don't check every seqlen_k values for kvcache + if (a.seqlen_k_ptr != nullptr) {{ + kernel_runner::run(s, a); + // make sure F_bn0 is divisible by F_bk1 + }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ + kernel_runner::run(s, a); + }} else {{ + kernel_runner::run(s, a); + }} + }} else {{ + kernel_runner::run(s, a); + }} +}} + +template<> 
+std::string fmha_fwd_splitkv_get_name_() +{{ + using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +namespace {{ +template +struct kernel_runner {{ +using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, + {F_dvpad}, + {F_lse}, + {F_squant}, + kLogMaxSplits, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {F_hdim}, + {F_bm0}, + {F_bn1}, + {F_mode}, + fmha_trait>; + +using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + fmha_pipeline_problem>; + +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + false, false>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVCombineKernel, + fmha_pipeline, + fmha_epilogue>; + +static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, + {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + +#include + +template<> +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + if (a.num_splits <= 8) {{ + kernel_runner<3>::run(s, a); + }} else if (a.num_splits <= 16) {{ + kernel_runner<4>::run(s, a); + }} else if (a.num_splits <= 32) {{ + kernel_runner<5>::run(s, a); + }} else if (a.num_splits <= 64) {{ + kernel_runner<6>::run(s, a); + }} else if (a.num_splits <= 128) {{ + kernel_runner<7>::run(s, a); + }} +}} + +template<> +std::string fmha_fwd_splitkv_combine_get_name_() +{{ + using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp" +FMHA_FWD_SPLITKV_API=""" +#include + +template +float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) +{{ + if(s.log_level_ > 0) + std::cout + << ", " << fmha_fwd_splitkv_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() + << std::flush; + + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + ); +}} + +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, 
{F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdSplitKVApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + mask : str + bias : str # + lse : str # + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + pagedkv : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ + f'{self.dvpad}-{self.pagedkv}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdSplitKVPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_squant : str # + F_pagedkv : str # t/f + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_squant == 't' : n += '_squant' + if self.F_pagedkv == 't' : n += '_pagedkv' + return n + +@dataclass +class FmhaFwdSplitKVCombinePipeline: + tag : str + + F_spad : str # true/false + F_dvpad : str # + F_lse : str # + F_squant : str # + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}' + if pn != '' : n += f'_{pn}' + if self.F_lse == 't' : n += '_lse' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdSplitKVApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t 
; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdSplitKVCombineTileSize: + F_bm0 : int # tile size along q seqlen + F_bn1 : int # tile size along v head_dim + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdSplitKVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdSplitKVPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdSplitKVApiTrait: + return FmhaFwdSplitKVApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + squant=self.F_pipeline.F_squant, + pagedkv=self.F_pipeline.F_pagedkv, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +@dataclass +class FmhaFwdSplitKVCombineKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdSplitKVCombineTileSize + F_pipeline : FmhaFwdSplitKVCombinePipeline + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( + F_idx = self.F_idx, + 
F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn1 = self.F_tile.F_bn1, + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_mode = MODE_MAP[self.F_mode]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1), + '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), + ## '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), + '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), + '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + +def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdSplitKVCombineTileSize(16, 16, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, 32, -1), + ## '96' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, 64, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, 128, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), + '256' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), + } + else: + return None + +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: + Pipeline = FmhaFwdSplitKVPipeline + Kernel = FmhaFwdSplitKVKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? 
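+        # Descriptive note (added for clarity): with the current tile dictionary every
+        # supported head dim (32/64/128/256) takes the 'qr' branch below, so the
+        # 'qr_async' pipelines in the else branch are effectively not generated yet (the
+        # TODO below defers async until the compiler is more stable). fp16/bf16 variants
+        # are enumerated over mask x bias x lse x paged-kv; fp8/bf8 only gets a
+        # non-padded col-major 'qr' pipeline per mask/bias, without lse or paged-kv.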
+ squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + # TODO: use async pipeline when compiler is more stable + if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: + # if True: + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + else: + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + if receipt == 1: + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/paged-kv kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask)) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdSplitKVApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if pipeline.F_pagedkv == 't': + # we only use batch mode kernels to handle (paged-) kvcache problems + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: + Pipeline = FmhaFwdSplitKVCombinePipeline + Kernel = FmhaFwdSplitKVCombineKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! 
the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(Pipeline('unused', spad, dvpad, lse, squant)) + elif dtype in ['fp8', 'bf8']: + # no need lse kernels + pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant)) + else: + assert False + return pipelines + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + gen.append(k) + + return gen + +def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None: + file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME + file_path.write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_fwd_splitkv_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d76627a725f86122b8a734f14e0a2cb0585c616 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -0,0 +1,997 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
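+// Standalone example that benchmarks and (optionally) validates the CK Tile FMHA
+// backward kernels. A typical invocation might look like the following (the exact
+// binary name depends on the build target):
+//   ./tile_example_fmha_bwd -prec=fp16 -b=2 -h=8 -s=3328 -d=128 -mask=t -v=1
+// Each flag corresponds to an option registered in create_args() below.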
+ +#include "fmha_bwd.hpp" +#include "ck_tile/host.hpp" +#include "mask.hpp" +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") + .insert("dbias", "0", "output bias gradient or not") + .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 
0 for " + "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") + .insert("drop_prefs", + "0", + "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("warmup", "5", "number of iterations before benchmarking the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("deterministic", + "0", + "if set to 1 will use multi-buffer reduction strategy for dq, atomic operations " + "will not be used"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) +{ + double rtol = 1e-2; + double atol = 1e-2; + if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN + { + rtol = 3.2e-2; + atol = 3.2e-2; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k < 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k < 0) + seqlen_k = seqlen_q; + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v < 0) + hdim_v = hdim_q; + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale = arg_parser.get_float("scale"); + if(scale == .0f) + scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); + bool use_dbias = arg_parser.get_bool("dbias"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + + if(use_dbias && bias.type != bias_enum::elementwise_bias) + { + std::cerr << "dbias only exists when bias type is elementwise" << std::endl; + return false; + } + + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + + mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + + int init_method = arg_parser.get_int("init"); + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0)
+ { + seed.reset(); + } + + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + bool deterministic = arg_parser.get_bool("deterministic"); + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; + + const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + + using TypeConfig = FmhaBwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using GemmDataType = typename TypeConfig::GemmDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using AccDataType = typename TypeConfig::AccDataType; + using DDataType = typename TypeConfig::DDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using ODataType = typename TypeConfig::ODataType; + using OGradDataType = typename TypeConfig::OGradDataType; + using QGradDataType = typename TypeConfig::QGradDataType; + using KGradDataType = typename TypeConfig::KGradDataType; + using VGradDataType = typename TypeConfig::VGradDataType; + using BiasGradDataType = typename TypeConfig::BiasGradDataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + + flop += nhead * (static_cast(3) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T + static_cast(2) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * real_seqlen_k * hdim_v + + sizeof(ODataType) * real_seqlen_q * hdim_v + + sizeof(OGradDataType) * real_seqlen_q * hdim_v + + sizeof(QGradDataType) * real_seqlen_q * hdim_q + + sizeof(KGradDataType) * real_seqlen_k * hdim_q + + sizeof(VGradDataType) * real_seqlen_k * hdim_v + + sizeof(LSEDataType) * real_seqlen_q); + } + } + + auto get_lengths = [&](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64; + const ck_tile::index_t nsplits = + deterministic ? 
ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1; + + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor v_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); + ck_tile::HostTensor bias_host( + bias.type == bias_enum::elementwise_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor lse_host( + std::array{shape_batch, nhead, shape_seqlen_q}); + ck_tile::HostTensor d_host( + std::array{shape_batch, nhead, shape_seqlen_q}); + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor dq_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor dk_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor dv_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v)); + ck_tile::HostTensor do_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor dbias_host( + use_dbias + ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor dq_acc_host( + i_perm + ? std::array{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q} + : std::array{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q}); + + if(init_method == 0) + { + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(do_host); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(do_host); + } + else if(init_method == 2) + { + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(bias_host); + ck_tile::FillTrigValue{}(do_host); + } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == nhead); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } + + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + do_buf.ToDevice(do_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if (permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if (iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias + << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval + << ", deterministic:" << deterministic << ", mask:" << mask << std::flush; + + std::size_t workspace_size = + dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024); + + if(deterministic == 1) + { + std::cout << "\nDeterministic mode ON: " << workspace_size + << " MByte memory workspace allocated" << std::endl; + } + + auto fmha_traits = fmha_bwd_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + mask.type, + bias.type, + use_dbias, + p_drop > 0.0f, + s_randval, + deterministic}; + auto fmha_args = [&]() { + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v); + const ck_tile::index_t stride_bias = (max_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? 
hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_bias = 0; + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q; + const ck_tile::index_t nhead_stride_dbias = + (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_bias = 0; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t split_stride_dq_acc = + (shape_batch * nhead * shape_seqlen_q * hdim_q); + + const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) { + if(drop_prefs) + { + return std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + return std::make_pair(drop_seed, drop_offset); + } + }(); + + return fmha_bwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + do_buf.GetDeviceBuffer(), + d_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), + dq_buf.GetDeviceBuffer(), + dk_buf.GetDeviceBuffer(), + dv_buf.GetDeviceBuffer(), + dbias_buf.GetDeviceBuffer(), + dq_acc_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + max_seqlen_k, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 
0 : nhead) + : stride_bias, + stride_o, + stride_randval, + stride_do, + stride_q, // stride_dq_acc + stride_q, // stride_dq + stride_dk, + stride_dv, + stride_dbias, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + nhead_stride_randval, + nhead_stride_do, + nhead_stride_lsed, + nhead_stride_q, // nhead_stride_dq_acc + nhead_stride_q, // nhead_stride_dq + nhead_stride_k, // nhead_stride_dk + nhead_stride_v, // nhead_stride_dv + nhead_stride_dbias, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_o, + batch_stride_randval, + batch_stride_do, + batch_stride_lsed, + batch_stride_q, // batch_stride_dq_acc + batch_stride_q, // batch_stride_dq + batch_stride_dk, + batch_stride_dv, + batch_stride_dbias, + split_stride_dq_acc, + mask.left, + mask.right, + static_cast(mask.type), + p_drop, + p_undrop, + drop_seed_offset}; + }(); + + float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); + if(ave_time < 0) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::flush << std::endl; + return true; + } + + bool pass = true; + + std::vector> q_host_refs; + std::vector> k_host_refs; + std::vector> v_host_refs; + std::vector> o_host_refs; + std::vector> randval_host_refs; + std::vector> p_hp_host_refs; + std::vector> p_lp_host_refs; + + randval_buf.FromDevice(randval_host.data()); + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 
0 : seqstart_k_host[wb]); + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); // lse_g_m + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n + ck_tile::HostTensor s_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n + ck_tile::HostTensor p_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision + ck_tile::HostTensor p_dropped_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision + ck_tile::HostTensor p_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + // clang-format on + + // reference + // S = scale * Q * K^T + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k + + if(bias.type == bias_enum::elementwise_bias) + { + // elementwise bias + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile:: + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + AccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? 
current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + AccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile:: + reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + ck_tile::reference_batched_softmax( + s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); + + if(p_drop > 0) + { + p_hp_host_ref.ForEach( + [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + else + { + p_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + + // O = P * V + ck_tile::reference_batched_gemm( + p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n + + // clang-format off + // permute + if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); + else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); + + lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); }); + // clang-format on + + q_host_refs.push_back(q_host_ref); + k_host_refs.push_back(k_host_ref); + v_host_refs.push_back(v_host_ref); + o_host_refs.push_back(o_host_ref); + p_hp_host_refs.push_back(p_hp_host_ref); + p_lp_host_refs.push_back(p_lp_host_ref); + if(p_drop > 0) + { + randval_host_refs.push_back(randval_host_ref); + } + } + + o_buf.ToDevice(o_host.data()); + lse_buf.ToDevice(lse_host.data()); + dq_buf.SetZero(); + dbias_buf.SetZero(); + dq_acc_buf.SetZero(); + + ck_tile::stream_config stream_config_v{ + nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; + fmha_bwd(fmha_traits, fmha_args, stream_config_v); + + dq_buf.FromDevice(dq_host.data()); + dk_buf.FromDevice(dk_host.data()); + dv_buf.FromDevice(dv_host.data()); + dbias_buf.FromDevice(dbias_host.data()); + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t 
real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + + ck_tile::HostTensor do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o + ck_tile::HostTensor ds_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision + ck_tile::HostTensor ds_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision + ck_tile::HostTensor dp_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision + ck_tile::HostTensor dbias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + ck_tile::HostTensor dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + + // clang-format off + if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); }); + else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); }); + // clang-format on + + // dP = dO@V x Z w/ dropout + // dP = dO@V w/o dropout + auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o + ck_tile::reference_batched_gemm( + do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o + + if(p_drop > 0) + { + ck_tile::reference_batched_dropout( + dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop); + } + + // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) + ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + auto idx_gmo = idx_gmn; + idx_gmo[2] = o; + do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * + ck_tile::type_convert(o_host_refs[wb](idx_gmo)); + } + self(idx_gmn) = ck_tile::type_convert( + p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); + }); + + if(use_dbias) + { + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + + // dV = P_drop^T@dO^T + // dV = P^T@dO^T w/o dropout + auto p_t_lp_host_ref = p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m + auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m + ck_tile::reference_batched_gemm( + p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m + + // dQ = scale * dS@K^T + auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n + ck_tile::reference_batched_gemm( + ds_lp_host_ref, + k_t_host_ref, + dq_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n + + // dK = scale * dS^T@Q^T + auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m + auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m + ck_tile::reference_batched_gemm( + ds_t_lp_host_ref, + q_t_host_ref, + dk_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m + + ck_tile::HostTensor dq_host_result( + 
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_result( + {nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_result( + {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + ck_tile::HostTensor dbias_host_result( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + + // clang-format off + // permute + if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + + if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(use_dbias) + { + if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + } + // clang-format on + + auto [rtol, atol] = get_elimit(hdim_q, hdim_v); + bool dq_cur_pass = ck_tile::check_err(dq_host_result, + dq_host_ref, + std::string("Error: QGrad Incorrect results!"), + rtol, + atol); + bool dk_cur_pass = ck_tile::check_err(dk_host_result, + dk_host_ref, + std::string("Error: KGrad Incorrect results!"), + rtol, + atol); + bool dv_cur_pass = ck_tile::check_err(dv_host_result, + dv_host_ref, + std::string("Error: VGrad Incorrect results!"), + rtol, + atol); + + bool dbias_cur_pass = true; + if(use_dbias) + { + dbias_cur_pass = ck_tile::check_err(dbias_host_result, + dbias_host_ref, + std::string("Error: BiasGrad Incorrect results!"), + rtol, + atol); + } + pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass); + if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass)) + { + std::cerr << "mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3b21a3257f850442d52deea806863e058c5d2849 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -0,0 +1,447 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
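+// Host-side interface for the FMHA backward kernels: per-dtype type configurations
+// (FmhaBwdTypeConfig), the fmha_bwd_args / fmha_bwd_traits argument structs, and
+// helpers that turn them into kernel arguments and launch grids. The fmha_bwd()
+// declaration at the end of this header is the entry point used by fmha_bwd.cpp.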
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "mask.hpp" +#include "bias.hpp" + +#include +#include +#include + +template +struct FmhaBwdTypeConfig; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using GemmDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::half_t; + using OGradDataType = ck_tile::half_t; + using QGradDataType = ck_tile::half_t; + using KGradDataType = ck_tile::half_t; + using VGradDataType = ck_tile::half_t; + using BiasGradDataType = ck_tile::half_t; +}; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using GemmDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::bf16_t; + using OGradDataType = ck_tile::bf16_t; + using QGradDataType = ck_tile::bf16_t; + using KGradDataType = ck_tile::bf16_t; + using VGradDataType = ck_tile::bf16_t; + using BiasGradDataType = ck_tile::bf16_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_bwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + void* dq_acc_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t max_seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dq_acc; + ck_tile::index_t stride_dq; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + ck_tile::index_t stride_dbias; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + ck_tile::index_t nhead_stride_dq_acc; + ck_tile::index_t nhead_stride_dq; + ck_tile::index_t nhead_stride_dk; + ck_tile::index_t nhead_stride_dv; + ck_tile::index_t nhead_stride_dbias; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + 
ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dq_acc; + ck_tile::index_t batch_stride_dq; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + ck_tile::index_t batch_stride_dbias; + ck_tile::index_t split_stride_dq_acc; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + float p_drop; + float p_undrop; + std::variant, std::pair> + drop_seed_offset; +}; + +template +auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) + { + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dq_acc, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) + { + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + 
args.d_ptr, + args.p_undrop, + args.seqstart_q_ptr, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed); + } + else + { // create batch mode kernel arguments + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqlen_q, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_do, + args.batch_stride_o, + args.batch_stride_lsed); + } + }(); + + dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode) + { + return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.split_stride_dq_acc); + } + else + { // create batch mode kernel arguments + return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.batch_stride_dq, + args.batch_stride_dq_acc, + args.split_stride_dq_acc); + } + }(); + + dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_bwd_dq_dk_dv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + using FmhaDropout = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsDeterministic = kIsDeterministic_; +}; + +template +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dq_dk_dv_get_name_(); + +template +struct fmha_bwd_dot_do_o_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dot_do_o_get_name_(); + +template +struct fmha_bwd_convert_dq_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kIsDeterministic = kIsDeterministic_; +}; + 
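+// Declarations for the dq_acc -> dq conversion step; as with the entry points above,
+// the definitions are provided by the generated kernel instantiations, not this header.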
+template +float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_convert_dq_get_name_(); + +// This is the public API, will be generated by script +struct fmha_bwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_dbias; + bool has_dropout; + bool is_store_randval; + bool is_deterministic; + // TODO: padding check is inside this api +}; +float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..14291715fb0b314d06f0b27699157753a316f452 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -0,0 +1,1512 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd.hpp" +#include "ck_tile/host.hpp" +#include "mask.hpp" +#include "rotary.hpp" +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()" +#endif + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert( + "s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" + "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)") + .insert("s_k", "-1", "seqlen_k (including new key/value), -1 means equal to s") + .insert("s_knew", + "0", + "seqlen_k for new key/value, 0 means not to use this at all; " + "-1 to choose s_knew in [1, s] randomly.") + .insert("s_kpad", + "-1", + "seqlen_k stride between 2 tokens, currently used in group-mode only\n" + "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" + "along seqlen, instead of packed. same as xformer kv_padding") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("scale_s", + "0", + "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" + "note when squant=1, this value will be modified by range_q/k") + .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") + .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") + .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") + .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. 
used if squant=1.") + .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") + .insert("squant", + "auto", + "if using static quantization fusion or not. auto: fp8 will default use squant, " + "other will not\n" + "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " + "P and O.\n" + "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " + "range_p, range_o") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") + .insert("lse", "0", "0 not store lse, 1 store lse") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", + "uf", + "init method. ui, uniform random int, ni, normalized random int\n" + "uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, " + "quantization") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") + .insert("drop_prefs", + "0", + "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert( + "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") + .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE") + .insert("num_splits", + "1", + "# of splits for key/value. 0 to determine actual number by heuristic") + .insert("page_block_size", "0", "paged-kvcache block size. 
0 means not to use paged-kvcache") + .insert("cache_batch_idx", "0", "whether to use an index map into the kvcache") + .insert("warmup", "5", "number of iterations before benchmarking the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +{ + // If we have enough to almost fill the SMs, then just use 1 split + if(batch_nhead_mblocks >= 0.8f * num_SMs) + { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits.
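+    // e.g. with num_n_blocks = 64: ceildiv(64, 12) == ceildiv(64, 11) == 6, so choosing
+    // 12 splits does no additional work and is skipped in favor of 11.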
+ auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || + ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + efficiency.push_back(0.f); + } + else + { + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) + { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + continue; + } + if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) + { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +int override_num_splits_if_necessary( + int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return num_splits; + } + + hipDeviceProp_t props{}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return num_splits; + } + + // tile size should match the generate.py + const int kM0 = 64; + const int kN1 = hdim_v; + + const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); + const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); + + if(num_splits < 1 && p_drop == 0.0f) + { + return num_splits_heuristic( + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + } + + return num_splits; +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k < 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v < 0) + hdim_v = hdim_q; + + ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); +#if !CK_TILE_FMHA_FWD_APPENDKV_API + if(seqlen_knew != 0) + { + std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl; + seqlen_knew = 0; + } +#endif + if(seqlen_knew < 0) + { + seqlen_knew = randint(1, arg_parser.get_int("s"), seed); + } + + ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); + if constexpr(!(std::is_same_v || + std::is_same_v)) + { + if(0 < rotary_dim) + { + std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl; + return false; + } + } +#if !CK_TILE_FMHA_FWD_APPENDKV_API + else if(0 < rotary_dim) + { + std::cerr << "rotary embedding is not supported. 
ignoring the 'rotary_dim' option" + << std::endl; + rotary_dim = 0; + } +#endif + if(!(rotary_dim <= hdim_q)) + { + std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; + return false; + } + else if(!(rotary_dim % 16 == 0)) + { + std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl; + return false; + } + + ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); +#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) + { + std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" + << std::endl; + page_block_size = 0; + } +#endif + if(!(page_block_size % 128 == 0)) + { + std::cerr << "only paged-kvcache block size divisible by 128 are currently supported" + << std::endl; + return false; + } + + bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); +#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API + if(use_cache_batch_idx) + { + std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } +#endif + if(0 < page_block_size && use_cache_batch_idx) + { + std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + // the input tensor layout for kvcache is same as batch mode + const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); + const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); + if(use_kvcache && mode != mode_enum::batch) + { + std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl; + mode = mode_enum::batch; + } + + auto [seqlen_qs, seqlen_ks, seqlen_kpads] = + decode_seqlen(mode, + batch, + arg_parser.get_str("s"), + arg_parser.get_str("s_k"), + arg_parser.get_str("s_kpad"), + /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, + use_kvcache); + // compute kvcache seqlen_k (before appending knew/vnew) + auto cache_seqlen_ks = seqlen_ks; + std::transform(cache_seqlen_ks.begin(), + cache_seqlen_ks.end(), + cache_seqlen_ks.begin(), + [&](auto seqlen_k) { return seqlen_k - seqlen_knew; }); + +#if 0 + // clang-format off + std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl; + // clang-format on +#endif + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale_s = arg_parser.get_float("scale_s"); + if(scale_s == .0f) + scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + + std::string squant_str = arg_parser.get_str("squant"); + bool squant = [&]() { + if(squant_str == "auto") + { + if(data_type == "fp8") + return true; + else + return false; + } + else + return atoi(squant_str.c_str()) != 0 ? 
true : false; + }(); + + float range_q = arg_parser.get_float("range_q"); + float range_k = arg_parser.get_float("range_k"); + float range_v = arg_parser.get_float("range_v"); + float range_p = arg_parser.get_float("range_p"); + float range_o = arg_parser.get_float("range_o"); + + float dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + + float scale_p = 1.f; + float scale_o = 1.f; + + if(squant) + { + scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max); + scale_p = dtype_max / range_p; + // scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)] + scale_o = range_p * range_v / range_o / dtype_max; + } + + std::string vlayout = arg_parser.get_str("vlayout"); + bool lse = arg_parser.get_bool("lse"); + + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); + mask_info mask = mask_info::decode( + arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore + + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + + std::string init_method = arg_parser.get_str("init"); + + const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); + + ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); +#if !CK_TILE_FMHA_FWD_SPLITKV_API + if(num_splits != 1) + { + std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl; + num_splits = 1; + } +#endif + + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 
1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; + + const auto seqstart_q_host = to_seqstarts(seqlen_qs); + const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + + using TypeConfig = FmhaFwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = std::numeric_limits::min(); + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + + flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + + static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k + + sizeof(ODataType) * real_seqlen_q * hdim_v); + } + } + + const ck_tile::index_t max_num_page_blocks = + (0 < page_block_size + ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size)) + : 0); + + // legalize num_splits according to other options + if(num_splits < 1) + { + num_splits = override_num_splits_if_necessary( + batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); + } + if(128 < num_splits) + { + std::cerr << "num_splits greater than 128 is not supported" << std::endl; + return false; + } +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < p_drop && (1 < num_splits || use_kvcache)) + { + std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" + << std::endl; + p_drop = 0.0f; + } +#endif + + static const auto get_lengths = [](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + bool is_v_rowmajor = vlayout == std::string("r"); + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_ks[0] + : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() + : seqstart_k_with_padding_host.back())); + + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + 0 < page_block_size + ? 
get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) + : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode + ck_tile::HostTensor knew_host( + 0 < seqlen_knew + ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor v_host( + 0 < page_block_size + ? (is_v_rowmajor + ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v) + : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) + : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); + ck_tile::HostTensor vnew_host( + 0 < seqlen_knew + ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) + : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew)) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor bias_host( + bias.type == bias_enum::elementwise_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + + auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin( + std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed); + + ck_tile::HostTensor lse_acc_host( + 1 < num_splits || use_kvcache + ? std::array{shape_batch, nhead, num_splits, shape_seqlen_q} + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor o_acc_host( + 1 < num_splits || use_kvcache ? std::array{shape_batch, + nhead, + num_splits, + shape_seqlen_q, + hdim_v} + : std::array{1, 1, 1, 1, 1}); + + // batch mode of lse data layout is [batch, nhead, seqlen_q] + // group mode of lse data layout is [nhead, total_seqlen_q] + ck_tile::HostTensor lse_host( + lse ? std::array{shape_batch, nhead, shape_seqlen_q} + : std::array{1, 1, 1} /* dummy shape for simplifying code */); + + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + + ck_tile::HostTensor block_table_host( + 0 < page_block_size ? std::array{batch, max_num_page_blocks / batch} + : std::array{1, 1}); + + ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx + ? 
std::array{batch} + : std::array{1}); + + if(init_method == "ui" || init_method == "0") + { + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + } + else if(init_method == "ni") + { + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + } + else if(init_method == "uf" || init_method == "1") + { + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + } + else if(init_method == "nf") + { + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(knew_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(vnew_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); + } + else if(init_method == "tf" || init_method == "2") + { + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(knew_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(vnew_host); + ck_tile::FillTrigValue{}(bias_host); + } + else if(init_method == "ufq" || init_method == "uf:q" || + init_method == "3") // suitable for fp8 quantization + { + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(knew_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(v_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(vnew_host); + + // bias_fp8 = qscale_bias * bias_fp32 + float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k); + // Assume bias is in [-1.f, 1.f] in original fp32 + ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); + } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == static_cast(nhead)); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } + iota_shuffle(block_table_host.begin(), block_table_host.end(), 0); + iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0); + + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_k_buf( + use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem cache_seqlen_k_buf( + need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + knew_buf.ToDevice(knew_host.data()); + v_buf.ToDevice(v_host.data()); + vnew_buf.ToDevice(vnew_host.data()); + bias_buf.ToDevice(bias_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() + : seqstart_k_with_padding_host.data()); + seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); + cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); + rotary_cos_buf.ToDevice(rotary_cos_host.data()); + rotary_sin_buf.ToDevice(rotary_sin_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); + block_table_buf.ToDevice(block_table_host.data()); + cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if(permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if(iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0] + << (seqlen_kpads[0] < 0 ? 
"" + : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) + << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias + << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant + << ", mask:" << mask << ", v:" << vlayout; +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(0 < rotary_dim) + { + std::cout << ", rotary_dim:" << rotary_dim << "(" + << (is_rotary_interleaved ? "inter" : "half") << ")"; + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(1 < num_splits) + { + std::cout << ", num_splits:" << num_splits; + } + if(0 < page_block_size) + { + std::cout << ", page_block_size:" << page_block_size; + } + if(use_cache_batch_idx) + { + std::cout << ", cache_batch_idx:" << use_cache_batch_idx; + } +#endif + std::cout << std::flush; + + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + + if constexpr(std::is_same_v>) + { + traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved + : rope_enum::half_rotated) + : rope_enum::none); + } + else // fmha_fwd_traits or fmha_splitkv_traits + { + traits.is_group_mode = (mode == mode_enum::group); + traits.mask_type = mask.type; + traits.bias_type = bias.type; + traits.has_lse = lse; + traits.do_fp8_static_quant = squant; + + if constexpr(std::is_same_v>) + { + traits.has_dropout = (p_drop > 0.0f); + } + } + }; + + const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) { + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size) + : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? seqlen_knew : nhead_k * seqlen_knew; + }(); + const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o_acc = (hdim_v); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = + (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q) + : (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); + const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v) + : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); + else + return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size) + : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); + }(); + const ck_tile::index_t nhead_stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? 
seqlen_knew * hdim_v : hdim_v; + else + return i_perm ? hdim_v * seqlen_knew : seqlen_knew; + }(); + const ck_tile::index_t nhead_stride_bias = + (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = + (0 < page_block_size ? (nhead_k * page_block_size * hdim_q) + : (nhead_k * shape_seqlen_k * hdim_q)); + const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); + const ck_tile::index_t batch_stride_v = + (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) + : (nhead_k * hdim_v * shape_seqlen_k)); + const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); + const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + // setup split_stride_* arguments (only used in split-kv kernel) + const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); + const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; // unused in group mode + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + if constexpr(std::is_same_v>) + { + args.knew_ptr = knew_buf.GetDeviceBuffer(); + args.vnew_ptr = vnew_buf.GetDeviceBuffer(); + args.seqlen_knew = seqlen_knew; + + args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer(); + + args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr); + args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr); + args.rotary_dim = rotary_dim; + args.has_mask = (mask.type != mask_enum::no_mask); + + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + + args.cache_batch_idx = + (use_cache_batch_idx ? 
cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.stride_knew = stride_knew; + args.stride_vnew = stride_vnew; + args.nhead_stride_knew = nhead_stride_knew; + args.nhead_stride_vnew = nhead_stride_vnew; + args.batch_stride_knew = batch_stride_knew; + args.batch_stride_vnew = batch_stride_vnew; + } + else // fmha_fwd_args or fmha_fwd_splitkv_args + { + args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(); + args.lse_ptr = lse_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqstart_q_ptr = + (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = + (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = + (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); + + args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + args.scale_p = scale_p; + args.scale_o = scale_o; + + args.stride_bias = + (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); + args.stride_o = stride_o; + args.nhead_stride_bias = nhead_stride_bias; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_bias = batch_stride_bias; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + + if constexpr(std::is_same_v>) + { + args.rand_val_ptr = randval_buf.GetDeviceBuffer(); + + args.stride_randval = stride_randval; + args.nhead_stride_randval = nhead_stride_randval; + args.batch_stride_randval = batch_stride_randval; + + args.p_drop = p_drop; + args.s_randval = s_randval; + if(drop_prefs) + { + args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); + } + } + else if constexpr(std::is_same_v>) + { + args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); + args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); + + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + + args.cache_batch_idx = + (use_cache_batch_idx ? 
cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.num_splits = num_splits; + + args.stride_o_acc = stride_o_acc; + args.nhead_stride_lse_acc = nhead_stride_lse_acc; + args.nhead_stride_o_acc = nhead_stride_o_acc; + args.batch_stride_lse_acc = batch_stride_lse_acc; + args.batch_stride_o_acc = batch_stride_o_acc; + args.split_stride_lse_acc = split_stride_lse_acc; + args.split_stride_o_acc = split_stride_o_acc; + } + } + }; + + const float appendkv_ave_time = [&] { +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(need_append_kvcache) + { + fmha_fwd_appendkv_traits fwd_appendkv_traits; + init_traits(fwd_appendkv_traits); + + fmha_fwd_appendkv_args fwd_appendkv_args; + init_args(fwd_appendkv_args); + + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config); + } +#endif + return 0.0f; + }(); + + const float fwd_ave_time = [&] { +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(1 < num_splits || use_kvcache) + { + fmha_fwd_splitkv_traits fmha_splitkv_traits; + init_traits(fmha_splitkv_traits); + + fmha_fwd_splitkv_args fmha_splitkv_args; + init_args(fmha_splitkv_args); + + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); + } +#endif + fmha_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_fwd_args fmha_args; + init_args(fmha_args); + + return fmha_fwd(fmha_traits, fmha_args, stream_config); + }(); + + if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + const float ave_time = (appendkv_ave_time + fwd_ave_time); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::flush << std::endl; + return true; + } + + o_buf.FromDevice(o_host.data()); + lse_buf.FromDevice(lse_host.data()); + randval_buf.FromDevice(randval_host.data()); + + auto p_compute_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::scales{scale_p}; + else + return ck_tile::identity{}; + }(); + + auto oacc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::identity{}; + }(); + + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; + + bool pass = true; + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t cache_b_idx = + (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch + ? 0 + : (seqlen_kpads[0] < 0 ? 
seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); }); + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // optionally apply RoPE to the q_host_ref + if(0 < rotary_dim) + { + decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); + + auto [rotary_cos_slice, rotary_sin_slice] = + slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); + + ck_tile::reference_batched_rotary_position_embedding( + q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, q_host_ref_ro, + /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask); + + q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) { + if(i_perm) { + k_host_ref.ForEach([&](auto& self, auto i) { + self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); + }); + } else { + k_host_ref.ForEach([&](auto& self, auto i) { + self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); + }); + } + } else +#endif + { + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); + } + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // copy Knew to the end of K + if(0 < seqlen_knew) + { + ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); + if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); }); + else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); }); + + // optionally apply RoPE to the knew_host_ref + auto* real_knew_host_ref = &knew_host_ref; + std::optional knew_host_ref_ro; + if(0 < rotary_dim) + { + knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); + + auto [rotary_cos_slice, rotary_sin_slice] = + slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); + + ck_tile::reference_batched_rotary_position_embedding( + knew_host_ref, + rotary_cos_slice, + rotary_sin_slice, + is_rotary_interleaved, + knew_host_ref_ro.value()); + + real_knew_host_ref = &knew_host_ref_ro.value(); + } + + (*real_knew_host_ref).ForEach([&](auto& self, auto i) { + k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i); + }); + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) { + if(is_v_rowmajor) { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); + }); + } else { + v_host_ref.ForEach([&](auto& self, 
auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); + }); + } + } + else + { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); + }); + } else { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); + }); + } + } + } else +#endif + { + if(is_v_rowmajor) { + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); + } + else + { + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); + } + } + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // copy Vnew to the end of V + if(0 < seqlen_knew) + { + ck_tile::HostTensor vnew_host_ref({nhead, hdim_v, seqlen_knew}); + if(is_v_rowmajor) + { + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); }); + } + else + { + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); }); + } + + vnew_host_ref.ForEach([&](auto& self, auto i) { + v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i); + }); + } +#endif + // clang-format on + + // reference + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s)); + + if(bias.type == bias_enum::elementwise_bias) + { + // elementwise bias + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? 
current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + SaccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + if(lse) + { + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func); + } + + if(p_drop > 0) + { + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + } + + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + + ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // clang-format off + // permute + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool cur_pass = ck_tile::check_err( + o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "OUT mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + + if(lse) + { + ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + }); + + cur_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); + + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << 
std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..251e61bc763d75b67ab1b8b2f7293906702888d3 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -0,0 +1,773 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/fmha.hpp" + +#include "bias.hpp" +#include "mask.hpp" +#include "rotary.hpp" + +#include +#include +#include + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + 
using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + +struct fmha_fwd_splitkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_acc_ptr; + void* o_acc_ptr; + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // kvcache mode (use same kernel as batch mode): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t 
stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o_acc; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + ck_tile::index_t batch_stride_o; + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; +}; + +struct fmha_fwd_appendkv_args +{ + void* q_ptr; + void* k_ptr; + const void* knew_ptr; + void* v_ptr; + const void* vnew_ptr; + + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_knew; + ck_tile::index_t batch; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0 + const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0 + ck_tile::index_t rotary_dim; + bool has_mask; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* cache_batch_idx; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_knew; + ck_tile::index_t stride_v; + ck_tile::index_t stride_vnew; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_knew; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_vnew; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_knew; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_vnew; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + 
args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_k, // only used for paged-kvcache + args.batch_stride_v, // only used for paged-kvcache + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.seqlen_q, + args.seqlen_k, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + dim3 grids = + Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel argumentszs + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqstart_q_ptr, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + 
args.split_stride_lse_acc, + args.split_stride_o_acc); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqlen_q, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.batch_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + }(); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.knew_ptr, + args.v_ptr, + args.vnew_ptr, + args.seqlen_q, + args.seqlen_k_ptr, + args.seqlen_knew, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.rotary_cos_ptr, + args.rotary_sin_ptr, + args.rotary_dim, + args.has_mask, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.stride_q, + args.stride_k, + args.stride_knew, + args.stride_v, + args.stride_vnew, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_knew, + args.nhead_stride_v, + args.nhead_stride_vnew, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_knew, + args.batch_stride_v, + args.batch_stride_vnew); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew); + + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); + +template +struct fmha_fwd_splitkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = 
kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_get_name_(); + +template +struct fmha_fwd_splitkv_combine_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_combine_get_name_(); + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_appendkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; + static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; + static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; + static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSk = kPadSk_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr auto RotaryEnum = RotaryEnum_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); + +// This is the public API, will be generated by script +struct fmha_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool has_dropout; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); + +struct fmha_fwd_splitkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. 
sync with BlockAttentionBiasEnum + bool has_lse; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, + fmha_fwd_splitkv_args, + const ck_tile::stream_config&); + +struct fmha_fwd_appendkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_v_rowmajor; + rope_enum rope_type; +}; +float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, + fmha_fwd_appendkv_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1b6664ccc05880e2bec2fc5b5b0f7daf1a65f9 --- /dev/null +++ b/example/ck_tile/01_fmha/generate.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import pkgutil +import sys +from typing import List, Optional + +import codegen.ops +from codegen.cmake_config import * + + +class HandlerId(IntEnum): + LIST_BLOBS = 0 + WRITE_BLOBS = 1 + +# inspect all modules under 'codegen.ops' and register API handlers +ops = [] +for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): + full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) + if full_module_name not in sys.modules: + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) +unwanted_prefix = 'fmha_' +handlers = dict( + [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, + (op.list_blobs, op.write_blobs)) for op in ops] +) +assert 0 < len(handlers) + +def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) / GEN_DIR + + output_dir.mkdir(parents=True, exist_ok=True) + + for api in api_list: + handler = handlers[api][HandlerId.WRITE_BLOBS] + handler(output_dir, kernel_filter, receipt, mask_impl) + +# list all the files that will be generated +def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: + assert output_file is not None + file_path = Path(output_file) + + # create an empty file / drop its contents if it exists + open(file_path, "w").close() + + for api in api_list: + handler = handlers[api][HandlerId.LIST_BLOBS] + handler(file_path, kernel_filter, receipt, mask_impl) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK fmha kernel", + ) + parser.add_argument( + "-d", + "--direction", # we keep 'direction' option for backward compatibility + "-a", + "--api", + default='fwd', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." 
+ ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="write all the blobs into a directory" + ) + parser.add_argument( + "-l", + "--list_blobs", + required=False, + help="list all the kernels to a file" + ) + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-m", + "--mask", + default="simplified", + required=False, + help="mask implementation, simplified/generic" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ + " 1: generate more instance to cover all hdim\n" + \ + " 2: Only generate instance for Flash attention integration" + ) + + args = parser.parse_args() + api_list = args.direction.split(',') + if args.list_blobs is not None: + list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) + else: + write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c77b700b16c2538c3490305e62bf06ad1eb659d3 --- /dev/null +++ b/example/ck_tile/01_fmha/mask.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep this in sync with ck_tile::GenericAttentionMaskEnum +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::mask_top_left) + os << "t(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "b(" << left << ":" << right << ")"; + else + { + os << "g(" << y << ":" << x << ")"; + } + } + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "xt" || t == "xb") + { + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = atoi(v.c_str()); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? 
mask_enum::mask_top_left : mask_enum::mask_bottom_right; + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = left_size; + tmp.right = right_size; + } + else + { + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? + tmp.right = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + } + else + { + auto set_causal_top_left = [&]() { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + }; + auto set_causal_bottom_right = [&]() { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; + }; + if(str == "t") + set_causal_top_left(); + else if(str == "b") + set_causal_bottom_right(); + else + { + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::mask_top_left) + { + set_causal_top_left(); + } + else if(tmp.type == mask_enum::mask_bottom_right) + { + set_causal_bottom_right(); + } + } + } + return tmp; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/01_fmha/misc/gamc.png b/example/ck_tile/01_fmha/misc/gamc.png new file mode 100644 index 0000000000000000000000000000000000000000..2c96951f30f99ff0345706c4f3031ab2a623ce1a Binary files /dev/null and b/example/ck_tile/01_fmha/misc/gamc.png differ diff --git a/example/ck_tile/01_fmha/rotary.hpp b/example/ck_tile/01_fmha/rotary.hpp new file mode 100644 index 0000000000000000000000000000000000000000..346f2a5e7ee5943234d02f750c76c603c9a43fd4 --- /dev/null +++ b/example/ck_tile/01_fmha/rotary.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include + +// keep sync with RotaryEmbeddingEnum +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + +template +std::tuple, ck_tile::HostTensor> +generate_rotary_cos_sin(ck_tile::index_t seqlen, + ck_tile::index_t rotary_dim, + std::optional seed = std::nullopt) +{ + // return dummy tensors if we won't apply RoPE at all + if(rotary_dim <= 0) + { + ck_tile::HostTensor dummy({1, 1}); + return std::make_tuple(dummy, dummy); + } + + std::mt19937 random_engine(seed.has_value() ? 
*seed : std::random_device{}()); + std::uniform_real_distribution generator(0.0f, 1.0f); + + const ck_tile::index_t num_rows = seqlen * 2; + const ck_tile::index_t num_cols = rotary_dim / 2; + + using std::begin, std::end; + + ck_tile::HostTensor angle({num_rows, num_cols}); + std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; }); + + ck_tile::HostTensor cos({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) { + return ck_tile::type_convert(std::cos(origin_value)); + }); + + ck_tile::HostTensor sin({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) { + return ck_tile::type_convert(std::sin(origin_value)); + }); + + return std::make_tuple(cos, sin); +} + +template +std::tuple, ck_tile::HostTensor> +slice_rotary_cos_sin(const ck_tile::HostTensor& cos, + const ck_tile::HostTensor& sin, + ck_tile::index_t seqlen_offset, + ck_tile::index_t seqlen) +{ + assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2); + assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1)); + + assert(static_cast(seqlen_offset + seqlen) <= cos.get_length(0)); + + const ck_tile::index_t num_rows = seqlen; + const ck_tile::index_t num_cols = cos.get_length(1); + + ck_tile::HostTensor cos_pt({num_rows, num_cols}); + cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); }); + + ck_tile::HostTensor sin_pt({num_rows, num_cols}); + sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); }); + + return std::make_tuple(cos_pt, sin_pt); +} diff --git a/example/ck_tile/01_fmha/script/benchmark_bwd.sh b/example/ck_tile/01_fmha/script/benchmark_bwd.sh new file mode 100755 index 0000000000000000000000000000000000000000..cfd792906cef94a35b66e2930859b1c21dbe32b7 --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark_bwd.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)" +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh new file mode 100755 index 0000000000000000000000000000000000000000..599c595a757a42c6b2e4d5cc53273cf3cf42a7ac --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -0,0 +1,31 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . 
-name tile_example_fmha_fwd -type f | head -n 1)" +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 64 128 256 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done + +for perm in 0 1 ; do + +$EXE -prec=fp8 -squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 + +done \ No newline at end of file diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..b5e6778aa55d9095ef45433f87dbc7231fad0a70 --- /dev/null +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/ +# +# run the script as "./run_full_test.sh +# input arguments: +# environment tag : a string describing the specifics of your test environment +# branch name : name of the branch in git repo (git status | grep -e 'On branch') +# host name : $hostname +# gpu architecture: e.g., gfx90a, or gfx942, etc. 
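+#
+# for reference, one possible invocation (run from the CK root; the four values below are
+# placeholders for illustration only, substitute your own environment tag, branch, host and arch):
+#   ./example/ck_tile/01_fmha/script/run_full_test.sh "rocm-docker-env" develop "$(hostname)" gfx942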
+ +#get the command line arguments: +export env_type=$1 +echo 'Environment type: ' $env_type +export branch=$2 +echo 'Branch name: ' $branch +export host_name=$3 +echo 'Host name: ' $host_name +export GPU_arch=$4 +echo 'GPU_arch: ' $GPU_arch + +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + #get GPU_arch and number of compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + +#run verification tests +example/ck_tile/01_fmha/script/smoke_test_fwd.sh +example/ck_tile/01_fmha/script/smoke_test_bwd.sh + +#run performance benchmarks +export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" +print_log_header $fmha_fwd_log $env_type $branch $host_name +example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log + +export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log" +print_log_header $fmha_bwd_log $env_type $branch $host_name +example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log + diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh new file mode 100755 index 0000000000000000000000000000000000000000..5ba3425e261c720d0bac55106a65319d9f5cedb9 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -0,0 +1,35 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1' +set -x +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 256 ; do +for mode in 0 1 ; do +for bias in "n" "a" ; do +for dbias in 0 ; do +for p_drop in 0.0 0.2 ; do +for deterministic in 0 ; do + +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS + +done +done +done +done +done +done +done +done +set +x diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh new file mode 100755 index 0000000000000000000000000000000000000000..b867cd6c07f050063e41b8a1a31f0b6e1c610307 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -0,0 +1,106 @@ 
+#!/bin/bash +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 + +TEST_SPLITKV=0 +TEST_APPENDKV=0 +# options: +# -s: run splitkv tests +# -a: run appendkv tests +while getopts ":sa" opt; do + case "${opt}" in + s) + TEST_SPLITKV=1 + ;; + a) + TEST_APPENDKV=1 + ;; + *) + ;; + esac +done + +run_fp16_bf16_tests() { + local NUM_SPLITS="1" + local PAGE_BLOCK_SIZE="0" + local CACHE_BATCH_IDX="0" + + if [ $TEST_SPLITKV -eq 1 ] ; then + NUM_SPLITS="$NUM_SPLITS 2 3" + PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" + CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" + fi + + for prec in "fp16" "bf16" ; do + for mode in 1 0 ; do + for perm in 0 1 ; do + for vlayout in "r" "c" ; do + for hdim in 32 64 128 256 ; do + for lse in 0 1 ; do + for bias in "n" "e" "a" ; do + for p_drop in 0.0 0.2 ; do + for num_splits in $NUM_SPLITS ; do + for page_block_size in $PAGE_BLOCK_SIZE ; do + for cache_batch_idx in $CACHE_BATCH_IDX ; do + + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx 
-kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done ; done ; done + done ; +} + +run_fp8_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp16_appendkv_tests() { + for s in $(seq 63 1 65) ; do + for s_k in 65 129 ; do + for s_knew in 0 64 $s_k ; do + for hdim in 32 64 128 256 ; do + for ri in 0 1 ; do + for rdim in 0 16 32 $hdim ; do + for page_block_size in 0 128 ; do + for cache_batch_idx in 0 1 ; do + + $EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done +} + +set -x + +run_fp16_bf16_tests +run_fp8_tests + +if [ $TEST_APPENDKV -eq 1 ] ; then + run_fp16_appendkv_tests +fi + +set +x diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..996032a7177977867dd00f564aea1ceacba86ed1 --- /dev/null +++ b/example/ck_tile/01_fmha/utils.hpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/container/span.hpp" + +enum class mode_enum +{ + batch = 0, + group +}; + +std::ostream& operator<<(std::ostream& stream, mode_enum mode) +{ + return stream << (mode == mode_enum::batch ? "batch" : "group"); +} + +std::vector to_seqstarts(ck_tile::span seqlens) +{ + std::vector seqstarts = {0}; + for(int32_t seqlen : seqlens) + { + seqstarts.push_back(seqstarts.back() + seqlen); + } + assert(seqstarts.size() == seqlens.size() + 1); + return seqstarts; +} + +std::vector generate_seqlens(mode_enum mode, + unsigned count, + int32_t seqlen_avg, + int32_t seqlen_min = -1, // if not negative, clamp min + int32_t seqlen_max = -1, // if not negative, clamp max + std::optional seed = std::nullopt) +{ + assert(0 < count); + + seqlen_min = (0 < seqlen_min ? seqlen_min : 1); + seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits::max()); + assert(seqlen_min <= seqlen_max); + + std::vector seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max)); + + if(mode == mode_enum::group && 1 < count) + { + using size_type = std::vector::size_type; + + std::mt19937 random_engine(seed.has_value() ? 
*seed : std::random_device{}()); + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); + + for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + // make sure each elements of seqlens is in range [seqlen_min, seqlen_max] + if(seqlens[to_decrease] == seqlen_min) + { + continue; + } + + const size_type to_increase = (to_decrease + next_step()) % count; + + if(seqlens[to_increase] >= seqlen_max) + { + continue; + } + + --seqlens[to_decrease]; + ++seqlens[to_increase]; + } + } + + return seqlens; +} + +std::vector generate_seqstarts(mode_enum mode, + unsigned count, + int32_t seqlen_avg, + int32_t seqlen_min = -1, + int32_t seqlen_max = -1, + std::optional seed = std::nullopt) +{ + return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed)); +} + +// return random integer generated uniformly in range [low, high] +template +auto randint(Int low, Int high, std::optional seed = std::nullopt) + -> std::enable_if_t, Int> +{ + std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution dist(low, high); + return dist(engine); +} + +// return random integers generated uniformly in range [low, high] +template +auto randints(ForwardIterator first, + ForwardIterator last, + Int low, + Int high, + std::optional seed = std::nullopt) + -> std::enable_if_t> +{ + std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution dist(low, high); + + std::generate(first, last, [&] { return dist(engine); }); +} + +/* + * decode the seqlen string from cmdline + * example (assume batch=3) + * q_val=1,2,3 k_val=4,5,6 -> OK + * q_val=1,2,3 -> OK, k same as q + * q_val=1,2 -> OK, q will rand remaining 1 element, k same as q + * q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element + * q_val=1,2,3,4 -> OK, but ignore exceed one + * + * q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q + * q_val=1,2 k_val=4 -> not OK, k must have same splits with q + */ +std::tuple, + std::vector, + std::vector> +decode_seqlen(mode_enum mode, + ck_tile::index_t batch, + std::string q_val, + std::string k_val, + std::string k_pad_val, + ck_tile::index_t seqlen_k_min = 0, + bool use_kvcache = false, + std::optional seed = std::nullopt) +{ +#define _S2I_(str_) static_cast(std::atoi((str_).c_str())) + if(mode == mode_enum::batch) + { + ck_tile::index_t q = _S2I_(q_val); + ck_tile::index_t k = _S2I_(k_val); + + auto s_q = std::vector(batch, q); + auto s_k = [&] { + const ck_tile::index_t seqlen_k_max = (k < 0 ? 
q : k); + std::vector seqlen_ks(batch, seqlen_k_max); + + if(1 < batch && use_kvcache) + { + // to keep the original s_k value, we always use seqlen_k_max in first batch + randints(std::next(seqlen_ks.begin()), + seqlen_ks.end(), + seqlen_k_min, + seqlen_k_max, + seed); + return seqlen_ks; + } + + return seqlen_ks; + }(); + auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + + // s_k should be greater than or equal to seqlen_k_min if provided + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + + return std::make_tuple(s_q, s_k, s_kpad); + } + else + { + ck_tile::index_t idx = 0; + std::string::size_type pos_q = 0; + std::string::size_type pos_k = 0; + std::string::size_type pos_kp = 0; + std::vector s_q; + std::vector s_k; + std::vector s_kpad; + while(true) + { + auto found_q = q_val.find(',', pos_q); + auto found_k = k_val.find(',', pos_k); + auto found_kp = k_pad_val.find(',', pos_kp); + + ck_tile::index_t q = _S2I_( + q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q)); + ck_tile::index_t k = _S2I_( + k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k)); + ck_tile::index_t kp = _S2I_(k_pad_val.substr( + pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp)); + + s_q.push_back(q); + s_k.push_back(k < 0 ? q : k); + s_kpad.push_back(kp); + + // s_k should be greater than or equal to seqlen_k_min + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + + idx++; + if(found_q == std::string::npos || idx >= batch) + { + break; + } + pos_q = found_q + 1; + pos_k = found_k == std::string::npos ? pos_k : found_k + 1; + pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1; + } + if(idx < batch) + { + auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed); + auto rem_k = + generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed); + + s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); + s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); + s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + } + return std::make_tuple(s_q, s_k, s_kpad); + } +#undef _S2I_ +} + +int env_get_int(const char* var_name, int default_int) +{ + char* v = getenv(var_name); + int r = default_int; + if(v) + r = std::atoi(v); + return r; +} + +template +std::enable_if_t> iota_shuffle(RandomAccessIterator first, + RandomAccessIterator last, + Int value, + std::optional seed = std::nullopt) +{ + std::iota(first, last, value); + + std::mt19937 engine(seed.has_value() ? 
*seed : std::random_device{}()); + std::shuffle(first, last, engine); +} diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1bf74bc0553296f498004c024477034dd31797d0 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -0,0 +1,44 @@ +set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd") +set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") +if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all") + set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS}) +endif() + +# generate a list of kernels, but not actually emit files at config sta +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs +) + +set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd") + +message("adding example ${EXAMPLE_LAYERNORM2D_FWD}") +add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) +target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) + +set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3573d70cd2615be0e1cc65a62613e4de0015f636 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/README.md @@ -0,0 +1,85 @@ +# Layernorm2D forward + +This folder contains example for Layernorm2D forward using `ck_tile` tile-programming implementation. + +# Implementation and feature support + +## welford online algorithm +We use welfold algorithm to update `mean`/`variance` block by block. For `N <=4096` case we can compute `mean`/`var`/`normalization` within one loop, we call it `one-pass`. For large N case, it is hard to keep `mean`/`var` inside register/LDS and then computation `normalization`, so we need to load input twice, first time to compute `mean`/`var` block-by-block, then load input another time to compute the `normalization`. We call it `two-pass`. 
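+
+The block-by-block update is the standard Welford/Chan parallel combine of `(mean, var, count)` summaries. As a rough host-side illustration only (plain NumPy, not the actual `ck_tile` kernel code; the names and block count below are made up for the sketch), merging per-block statistics into the row statistic looks like this:
+
+```
+import numpy as np
+
+def welford_combine(mean_a, var_a, n_a, mean_b, var_b, n_b):
+    # merge two (mean, population variance, count) summaries into one
+    n = n_a + n_b
+    delta = mean_b - mean_a
+    mean = mean_a + delta * n_b / n
+    m2 = var_a * n_a + var_b * n_b + delta * delta * n_a * n_b / n
+    return mean, m2 / n, n
+
+row = np.random.randn(4096).astype(np.float32)
+mean, var, n = 0.0, 0.0, 0
+for blk in np.split(row, 16):  # walk the row block by block, like the one-pass pipeline
+    mean, var, n = welford_combine(mean, var, n, blk.mean(), blk.var(), blk.size)
+
+print(mean - row.mean(), var - row.var())  # both differences should be ~0
+```
+
+This is also why the `two-pass` path exists: once `mean`/`var` for a whole row no longer fit in register/LDS, the row has to be read a second time for the normalization step.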
+
+## mean/variance save
+In the training case, the mean/variance also need to be stored out (TBD, not supported yet).
+
+## prenorm/postnorm
+
+![](misc/pnorm.png)
+
+Since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example supports the pattern through kernel fusion. Note that both `prenorm` and `postnorm` need an elementwise add of a `shortcut` tensor before the actual layernorm computation, and optionally store that result back to global memory. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without the store (not generated by default).
+
+## smooth-quant/dynamic-quant
+We support smooth/dynamic quantization for `int8` output by setting `-fquant=1` and `-prec_o=int8`. In this case the output goes through a rowwise dynamic quantization as shown below. Note that smooth-quant requires a `(1*N)` per-channel input scale (fp32 in our example, though this is customizable), which is elementwise-multiplied into each row before the rowwise dynamic quant is computed. Setting `-fquant=2` skips the per-channel scale stage and performs only the dynamic quant; this case is supported by the kernel template but not generated by default (TBD: add a filter in generate.py to support on-demand codegen).
+![](misc/dquant.png)
+
+```
+# assume output int8, hidden_states is [m, n] shape and in fp16/bf16
+# [m, 1]
+per_token_amax, _ = torch.max(
+    input=torch.abs(hidden_states),
+    dim=-1,
+    keepdim=True
+)
+per_token_scale = per_token_amax.to(dtype=torch.float32) / 127.0
+
+# quant hidden_states
+hidden_states = (hidden_states / per_token_scale).to(dtype=torch.int8)
+
+return hidden_states, per_token_scale
+# hidden_states is now int8 and will be fed to the next layer as input
+# per_token_scale will be used as the dequant factor in a later layer
+```
+
+## build
+```
+# in the root of ck_tile
+mkdir build && cd build
+sh ../script/cmake-ck-dev.sh ../  # you can replace this with gfx90a, gfx942...
+make tile_example_layernorm2d_fwd -j
+```
+This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
+
+## example
+```
+args:
+          -m         m dimension (default:3328)
+          -n         n dimension (default:4096)
+          -stride    stride per row, if -1 then equal to n (default:-1)
+          -e         epsilon (default:1e-5)
+          -save_mv   save mean/variance(invstd) or not. set to 1 in training case (default:0)
+          -v         cpu validation or not (default:1)
+          -kname     print kernel name or not (default:1)
+          -prec_i    input precision (default:fp16)
+          -prec_o    output precision, set auto will be the same as input (default:auto)
+          -prec_sx   output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
+          -prec_sy   output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
+          -fadd      fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
+          -fquant    fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
+          -warmup    cold iter (default:5)
+          -repeat    hot iter (default:20)
+
+```
+
+## limitations
+Note that `fquant=2`, `fadd=2`, and `prec_sx`/`prec_sy` other than `fp32` are not generated by default, although the kernel template supports them (TBD: add a flag in generate.py to generate those instances on demand). Besides, the `N>8192` case uses the two-pass pipeline by default, where `-fquant=1/2` is not supported yet. If you need `N>8192` together with `fused+residual+store`, you can combine this example with `12_smoothquant`, building layernorm+residual plus smoothquant as two kernels for that purpose.
+
+```
+# some cases
+# standard fp16 layernorm 2d, m=10.
n=1024 +./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 + +# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant, output in int8 +./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 + +# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant+fused-add-store, output in int8 +./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1 + +``` diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..ca9e432a4f1ff914abaa851ffc0dacf7d0716f56 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -0,0 +1,680 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Any +import functools +import itertools +import copy +from dataclasses import dataclass + +def get_if_str(idx, total, lase_else = True): + if idx == 0: + return 'if' + elif idx < total - 1: + return 'else if' + else: + if lase_else: + return 'else' + else: + return 'else if' + +FUSED_ADD_ENUM_STR_MAP = [ + 'no', + 'pras', # pre-norm + 'pra' ] # post-norm + +FUSED_FUSED_SWEEP_STR_MAP = [ + 'no', + 'dquant' ] + +DATA_TYPE_MAP = {'fp32' : 'float', + 'fp16' : 'ck_tile::fp16_t', + 'bf16' : 'ck_tile::bf16_t', + 'int8' : 'ck_tile::int8_t'} + +def BOOL_MAP(b_) -> str: + if b_: + return 'true' + else: + return 'false' + +class layernorm_fwd_codegen: + API_TRAITS_DEFINE = """ +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct layernorm2d_fwd_traits_ +{ + using XDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using XScaleDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape 
= ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; + static constexpr bool kFastFDiv = kFastFDiv_; + static constexpr bool kTwoPass = kTwoPass_; + static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; + static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; +}; + +template +using traits_ = layernorm2d_fwd_traits_; +""" + API_COMMON_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "layernorm2d_fwd.hpp" +#include +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = layernorm2d_fwd_args; + +{F_traits_define} + +template +float layernorm2d_fwd_(const S& s, A a) +{{ + using XDataType = typename Traits_::XDataType; + using YDataType = typename Traits_::YDataType; + using XScaleDataType = typename Traits_::XScaleDataType; + using YScaleDataType = typename Traits_::YScaleDataType; + using ComputeDataType = typename LayerNormTypeConfig::ComputeDataType; + + using PipelineTraits = ck_tile::Layernorm2dFwdTraits(Traits_::kFusedAdd), + static_cast(Traits_::kFusedQuant)>; + using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< + typename LayerNormTypeConfig::XDataType, + typename LayerNormTypeConfig::GammaDataType, + typename LayerNormTypeConfig::BetaDataType, + typename LayerNormTypeConfig::ComputeDataType, + typename LayerNormTypeConfig::YDataType, + typename LayerNormTypeConfig::MeanDataType, + typename LayerNormTypeConfig::InvStdDataType, + typename LayerNormTypeConfig::XScaleDataType, + typename LayerNormTypeConfig::YScaleDataType, + typename Traits_::Shape, + PipelineTraits>; + + using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; + using Default2DEpilogue = ck_tile::Default2DEpilogue; + + static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; + using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + + using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; + + using Epilogue = std::conditional_t; + + using Kernel = ck_tile::Layernorm2dFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); +}} + +""" + + API_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include "layernorm2d_fwd.hpp" + +{F_traits_define} + +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a); + +float layernorm2d_fwd(layernorm2d_fwd_traits t, + layernorm2d_fwd_args a, + const ck_tile::stream_config& s) +{{ + float r = -1; +{F_dispatch} + return r; +}} + +""" + + API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{ +{F_per_n_case} + }} +""" + API_PER_N_CASE=""" {F_if} {F_N_COND} {{ +{F_inner_dispatch} + }} +""" + API_INNER_CASE=""" {F_if} {F_VEC_COND} + r={F_instance_func}(s, a); +""" + + INSTANCE_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_api_common.hpp" + +// clang-format off +// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep +{F_instance_def} +// clang-format on + +""" + + def __init__(self, working_path, kernel_filter): + self.working_path = working_path + self.kernel_filter = kernel_filter + + class k_fuesd_add_enum(IntEnum): + F_NO_ADD = 0 + F_PRE_ADD = 1 + F_PRE_ADD_STORE_RESIDUAL = 2 + + class k_fused_sweep_enum(IntEnum): + F_NO_SWEEP = 0 + F_RENORM = 1 + F_DYNAMIC_QUANT = 2 + + @dataclass + class k_traits: + F_kPadN : bool + F_kSaveMeanInvStd : bool + F_kTwoPass : bool + F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum + F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum + + @dataclass + class k_shape: + F_BlockTile : List[int] + F_WarpPerBlock : List[int] + F_WarpTile : List[int] + F_Vector_ : List[int] + @property + def F_BlockSize(self) -> int: + return functools.reduce(lambda a, b: a*b, self.F_WarpTile) + + @dataclass + class k_problem: + F_XDataType : str + F_GammaDataType : str + F_BetaDataType : str + F_ComputeDataType : str + F_YDataType : str + F_MeanDataType : str + F_InvStdDataType : str + F_BlockShape : str + F_Traits : Any #k_traits + + @dataclass + class k_pipeline_one_pass: + F_Problem : Any #k_problem + + @dataclass + class k_pipeline_two_pass: + F_Problem : Any #k_problem + + @dataclass + class default_2d_epilogue_problem: + F_AccDataType : str + F_ODataType : str + F_kPadM : bool + F_kPadN : bool + + @dataclass + class default_2d_epilogue: + F_problem : Any + + @dataclass + class k_kernel: + F_pipeline : Any + F_epilogue : Any + + @dataclass + class h_traits: + F_XDataType : str + F_YDataType : str + F_XScaleDataType : str + F_YScaleDataType : str + F_Repeat_M : int + F_Repeat_N : int + F_ThreadPerBlock_M : int + F_ThreadPerBlock_N : int + F_Vector_N : int + F_kPadN : bool + F_kSaveMeanInvStd_ : bool + F_kFastFDiv_ : bool + F_kTwoPass_ : bool + F_kFusedAdd : int + F_kFusedQuant : int + + @property + def trait_name(self) ->str: + t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}' + t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + return t_ + + # string when calling this kernel + @property + def call_name(self) -> str: + return f'layernorm2d_fwd_>' + + # string when define this kernel + @property + def def_name(self) -> str: + return f'template float 
layernorm2d_fwd_>(const S&, A);' + + # this class hold kernel under same source file + @dataclass + class h_instance: + F_DataTypePair : str + F_N : str + F_add : int + F_sweep : int + instance_list : List[Any] # List[h_traits] + + @property + def name(self) -> str: + prec_i, prec_o = self.F_DataTypePair.split(',') + dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' + nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}' + if self.F_add != 0: + nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] + if self.F_sweep != 0: + nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + return nnn + + @property + def instance_name(self) ->str: + return self.name + + @property + def content(self) ->str: + instance_defs = '' + for ins in self.instance_list: + instance_defs += ins.def_name + '\n' + return layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs) + + @property + def name_api(self) -> str: + return 'layernorm2d_fwd_api' + + @property + def name_common_header(self) -> str: + return 'layernorm2d_fwd_api_common' + + @property + def content_api(self) -> str: + # 1 sort based on dtype + t_dtype_dict = dict() + blobs = self.get_blobs() + for blob in blobs: + if blob.F_DataTypePair not in t_dtype_dict: + t_dtype_dict[blob.F_DataTypePair] = {} + if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]: + t_dtype_dict[blob.F_DataTypePair][blob.F_N] = [] + t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob) + + d_str = '' + for i_d, dtype_ in enumerate(t_dtype_dict): + blob_per_t = t_dtype_dict[dtype_] + n_str = '' + for i_n, n_ in enumerate(blob_per_t): + blob_per_n = blob_per_t[n_] + inner_str = "" + for i_b, b_ in enumerate(blob_per_n): + # generate single kernel instance file + #vec_str = "" + for i_ins, ins in enumerate(b_.instance_list): + idx_in_n = i_b * len(b_.instance_list) + i_ins + len_in_n = len(blob_per_n) * len(b_.instance_list) + # _if = 'if' if i_ins == 0 else 'else if' + if ins.F_kFusedQuant == 0: + _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) + elif ins.F_kFusedQuant == 1: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType) + elif ins.F_kFusedQuant == 2: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) + _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, + f_sweep_cond = _sweep_cond) + inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), + F_VEC_COND = _cond, F_instance_func=ins.call_name) + #inner_str = inner_str + vec_str + n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' + n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) + prec_i, prec_o = dtype_.split(',') + d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) + + api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str) + return api_base + + @property + def content_common_header(self) -> str: + return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) + + def get_blobs(self): + h_traits = 
layernorm_fwd_codegen.h_traits + h_instance = layernorm_fwd_codegen.h_instance + + dynamic_quant_out_dtype = ['int8'] + # some predefined support range + # (prec_i,prec_o) for simplicity this string will be used as key for dict + scale_list = [('fp32,fp32')] + dtype_list = [('fp16,fp16'), ('bf16,bf16'), + ('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out + #fused_add_list = [0, 1, 2] + #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant + fused_add_list = [0, 1] + fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant + + # rm rn tm tn vn pd mv fdiv 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, 
False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]} + total_blob = list() + for hs_key in h_trait_dict: + hs = h_trait_dict[hs_key] + current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N + for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): + prec_i, prec_o = dtype.split(',') + scale_x, scale_y = scale_type.split(',') + if prec_o in dynamic_quant_out_dtype and fused_quant != 1: + continue # skip non dynamic quant case + if fused_quant == 1 and hs_key == 'big': + continue + current_hs = list() + for chs_ in hs: + h_ = copy.copy(chs_) # copy the base instance out + h_.F_XDataType = prec_i + h_.F_YDataType = prec_o + h_.F_XScaleDataType = scale_y + h_.F_YScaleDataType = scale_x + h_.F_kFusedAdd = fused_add + h_.F_kFusedQuant = fused_quant + current_hs.append(h_) # + "\n" + #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = 'big' if hs_key == 'big' else current_n + total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) + return total_blob + + def list_blobs(self) -> None: + w_p = Path(self.working_path) + list_p = w_p / 'layernorm2d_fwd_blobs.txt' + blobs = self.get_blobs() + with list_p.open('w') as list_f: + # api related file + list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") + list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") + # kernel instance file + for b in blobs: + list_f.write(str(w_p / (b.name + ".cpp")) + "\n") + + def gen_blobs(self) -> None: + w_p = Path(self.working_path) + (w_p / (self.name_api + ".cpp")).write_text(self.content_api) + (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) + blobs = self.get_blobs() + for b in blobs: + (w_p / (b.name + ".cpp")).write_text(b.content) + +def list_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + layernorm_fwd_codegen(args.working_path, args.filter).list_blobs() + + +def gen_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs() + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK layernorm kernel", + ) + parser.add_argument( + "-a", + "--api", + default='fwd[all]', + required=False, 
+ help="supply API(s) to generate (default: fwd). separated by comma." + ) + + # the directory for list_blobs/gen_blobs to write files into + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated" + ) + + # this script have 2 modes + # 1) list_blobs mode, will generate a txt file with all the files going to be generated. + # this is useful in build system like cmake to construct source code dependency, by + # reading the content out of this file + # 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework + # like FA, only need to use this mode + parser.add_argument( + "-l", + "--list_blobs", + action='store_true', + help="list all the kernels to a file, " + ) + + parser.add_argument( + "-g", + "--gen_blobs", + action='store_true', + help="generate all kernels into different tile" + ) + + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-t", + "--traits", + default="all", + required=False, + help="enable/disable some feature. default generate all" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt." + ) + + args = parser.parse_args() + + # print(f'{args.list_blobs}-{args.gen_blobs}') + if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)): + print('gen_blobs/list_blobs must specify only one option') + sys.exit() + + p = Path(args.working_path) + if not p.exists(): + p.mkdir() + + if args.list_blobs: + list_blobs(args) + else: + gen_blobs(args) diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b49c04619d54c6c401128739a236259e04c54dd7 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -0,0 +1,431 @@ +#include "ck_tile/host.hpp" +#include "layernorm2d_fwd.hpp" +#include +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("x_stride", "-1", "x row_stride, if -1 then equal to n") + .insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n") + .insert("y_stride", "-1", "y row_stride, if -1 then equal to n") + .insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec_i", "fp16", "input precision") + .insert("prec_o", "auto", "output precision, set auto will be the same as input") + .insert("prec_sx", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1") + .insert("prec_sy", + "auto", + "output quant scale type, set auto will use fp32. 
used when fquant=1 or 2") + .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); + if(x_stride < 0) + x_stride = n; + ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride"); + if(xr_stride < 0) + xr_stride = n; + ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); + if(y_stride < 0) + y_stride = n; + ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride"); + if(yr_stride < 0) + yr_stride = n; + float epsilon = arg_parser.get_float("e"); + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sx == "auto") + { + prec_sx = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } + + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + if(fused_quant == 1 && prec_o != "int8") + { + std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl; + return false; + } + + assert(x_stride >= n); + + using TypeConfig = LayerNormTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using BetaDataType = typename TypeConfig::BetaDataType; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; + + using MeanDataType = + std::conditional_t; + using InvStdDataType = + std::conditional_t; + + using ComputeDataType = typename TypeConfig::ComputeDataType; + + // host verify + ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); + ck_tile::HostTensor gamma_host({n}); + ck_tile::HostTensor beta_host({n}); + + ck_tile::HostTensor x_residual_host({m, n}, {xr_stride, 1}); + ck_tile::HostTensor y_residual_host({m, n}, {yr_stride, 1}); + + ck_tile::HostTensor y_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {y_stride, 1}); + + ck_tile::HostTensor mean_host_ref({m}); + ck_tile::HostTensor invStd_host_ref({m}); + ck_tile::HostTensor y_scale_host_ref({m}); + ck_tile::HostTensor y_scale_host_dev({m}); + + ck_tile::HostTensor x_scale_host({n}); + ck_tile::HostTensor x_scale_host_dev({n}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(x_scale_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
y_buf(y_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + beta_buf.ToDevice(beta_host.data()); + x_residual_buf.ToDevice(x_residual_host.data()); + x_scale_buf.ToDevice(x_scale_host.data()); + + auto prec_str = [&]() { + auto base_str = prec_i; + if(prec_i != prec_o) + { + base_str += "|" + prec_o; + } + if(fused_quant == 1) + { + base_str += std::string("(") + prec_sy + ")"; + } + return base_str; + }(); + + std::cout << "[" << prec_str << "]" + << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride + << ", yr_stride:" << yr_stride << std::flush; + + layernorm2d_fwd_traits traits{ + prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; + + layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, + gamma_buf.GetDeviceBuffer(), + beta_buf.GetDeviceBuffer(), + + y_buf.GetDeviceBuffer(), + fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr, + nullptr, // p_mean, unsupported yet + nullptr, // p_invStd, unsupported yet + + epsilon, + m, + n, + x_stride, // x row_stride + xr_stride, // x residule row stride + y_stride, // y row stride + yr_stride}; // y residule row stride + + float ave_time = layernorm2d_fwd( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + if(ave_time < 0) + { + std::cout << " not supported!" << std::endl << std::flush; + return false; + } + + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + + sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + // reference + if(fused_add != 0) + { + // fused pre_add/pre_add_store + // TODO we accumulate directly to x_host for simplcity here... + + std::transform(x_host.mData.cbegin(), + x_host.mData.cend(), + x_residual_host.mData.cbegin(), + x_host.mData.begin(), + [](auto x_, auto r_) { + auto o_ = ck_tile::type_convert(x_) + + ck_tile::type_convert(r_); + return ck_tile::type_convert(o_); + }); + } + ck_tile::reference_layernorm2d_fwd( + x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + + if(fused_quant != 0) + { + auto dquant_functor = [&](int m_, auto& o_, auto& acc_) { + int N_ = acc_.mDesc.get_lengths()[1]; + if(fused_quant == 1) + { + for(int n_ = 0; n_ < N_; n_++) + { + // input smooth outlier + acc_(m_, n_) = + acc_(m_, n_) * ck_tile::type_convert(x_scale_host(n_)); + } + } + ComputeDataType absmax = static_cast(0); + for(int n_ = 0; n_ < N_; n_++) + { + const auto a = ck_tile::abs(acc_(m_, n_)); + absmax = a > absmax ? 
a : absmax; + } + // printf("cpu:absmax:%f\n", absmax); + ComputeDataType y_scale = absmax / static_cast(127.0); + y_scale_host_ref(m_) = ck_tile::type_convert(y_scale); + for(int n_ = 0; n_ < N_; n_++) + { + o_(m_, n_) = ck_tile::type_convert(acc_(m_, n_) / y_scale); + } + }; + + ck_tile::reference_layernorm2d_fwd(x_host, + gamma_host, + beta_host, + y_host_ref, + mean_host_ref, + invStd_host_ref, + epsilon, + dquant_functor); + } + else + { + ck_tile::reference_layernorm2d_fwd( + x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + } + + y_buf.FromDevice(y_host_dev.data()); + + ck_tile::HostTensor y_residual_host_dev({m, n}, {yr_stride, 1}); + if(fused_add == 1) + { + y_residual_buf.FromDevice(y_residual_host_dev.data()); + } + + auto [rtol, atol] = get_elimit(); + + if(x_stride == n) + { + pass = ck_tile::check_err( + y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + if(fused_add == 1) + { + pass &= ck_tile::check_err(y_residual_host_dev, + x_host, + std::string("ADD Error: Incorrect results!"), + rtol, + atol); + } + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector y_host_dev_row(y_host_dev.begin() + i_r * y_stride, + y_host_dev.begin() + i_r * y_stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * y_stride, + y_host_ref.begin() + i_r * y_stride + n); + pass &= ck_tile::check_err(y_host_dev_row, + y_host_ref_row, + std::string("OUT[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + if(fused_add == 1) + { + std::vector y_residual_host_dev_row( + y_residual_host_dev.begin() + i_r * yr_stride, + y_residual_host_dev.begin() + i_r * yr_stride + n); + std::vector y_residual_host_ref_row( + x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n); + pass &= ck_tile::check_err(y_residual_host_dev_row, + y_residual_host_ref_row, + std::string("ADD[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + if(fused_quant == 1) + { + y_scale_buf.FromDevice(y_scale_host_dev.data()); + pass &= ck_tile::check_err(y_scale_host_dev, + y_scale_host_ref, + std::string("SCALE Error: Incorrect results!"), + rtol, + atol); + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sx == "auto") + { + prec_sx = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } + int save_mv = arg_parser.get_int("save_mv"); + + // no dynamic quant case + if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 
0 : -2; + } + + // dynamic quant case, only in inference + else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a0f2db0e8a478da5a4302fe7439aa1354d3b923a --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/layernorm2d.hpp" +#include + +template +struct LayerNormTypeConfig; + +template +struct LayerNormTypeConfig +{ + using XDataType = ck_tile::half_t; + using YDataType = OutType; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; + using ComputeDataType = float; + using XScaleDataType = XScaleDataType_; + using YScaleDataType = YScaleDataType_; +}; + +template +struct LayerNormTypeConfig +{ + using XDataType = ck_tile::bf16_t; + using YDataType = OutType; + using GammaDataType = ck_tile::bf16_t; + using BetaDataType = ck_tile::bf16_t; + using MeanDataType = ck_tile::bf16_t; + using InvStdDataType = ck_tile::bf16_t; + using ComputeDataType = float; + using XScaleDataType = XScaleDataType_; + using YScaleDataType = YScaleDataType_; +}; + +// runtime args +struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs +{ +}; + +// This is the public API, will be generated by script +struct layernorm2d_fwd_traits +{ + std::string prec_i; // input precision + std::string prec_o; // output precision + + // if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set + // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise + // can set arbitrary(will skip check) + std::string prec_sx; // x-scale, used for [1*N] input smooth quant + std::string prec_sy; // y-scale, used for [M*1] output for next layer + + bool save_mean_var; // + int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant +}; + +float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/02_layernorm2d/misc/dquant.png b/example/ck_tile/02_layernorm2d/misc/dquant.png new file mode 100644 index 0000000000000000000000000000000000000000..28b1a61a14ea6774191fc2ac54f195cb86477f9b Binary files /dev/null and b/example/ck_tile/02_layernorm2d/misc/dquant.png differ diff --git a/example/ck_tile/02_layernorm2d/misc/pnorm.png b/example/ck_tile/02_layernorm2d/misc/pnorm.png new file mode 100644 index 0000000000000000000000000000000000000000..65a27e8751fa316d585c9b7d0340f3c425a71ec1 Binary files /dev/null and b/example/ck_tile/02_layernorm2d/misc/pnorm.png differ diff --git a/example/ck_tile/02_layernorm2d/script/perf_test.sh b/example/ck_tile/02_layernorm2d/script/perf_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..5a34e19280ded743fa4e846fbda4ac1311614ed0 --- /dev/null +++ 
b/example/ck_tile/02_layernorm2d/script/perf_test.sh @@ -0,0 +1,37 @@ +#!/bin/sh +EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)" + +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/script/smoke_test.sh b/example/ck_tile/02_layernorm2d/script/smoke_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..b7fd354bb8e05647dc66ee9f4699757bc8378d8b --- /dev/null +++ b/example/ck_tile/02_layernorm2d/script/smoke_test.sh @@ -0,0 +1,34 @@ +#!/bin/sh +EXE="$(find . 
-name tile_example_layernorm2d_fwd -type f | head -n 1)" + +for fquant in "" "-fquant=1 -prec_o=int8"; do +for pr_i in "fp16" "bf16" ; do +for fadd in "0" "1"; do +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 +#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 +done +done +done diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8ae46cadc65fa56ae93dbd8668c453df9a7451f3 --- /dev/null +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) +add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp) diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e9ffe72a9152d22b62680369ac03ab569c0eecce --- /dev/null +++ b/example/ck_tile/03_gemm/README.md @@ -0,0 +1,34 @@ +# GEMM Matrix Multiplication + +This folder contains example for GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The basic pipeline method on the gemm calculation +make tile_example_gemm_basic -j +# The memory bound pipeline on the gemm calculation +make tile_example_gemm_mem_pipeline -j +``` +This will result in an executable `build/bin/tile_example_gemm_basic` + +## example +``` +args: + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. 
fp16/bf16/fp8/bf8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b7d869344238d1ecf3a2518683d961e2d5292c96 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "gemm_basic.hpp" + +template +float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) +{ + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool kTilePermute = false; + // The rank and permutation will also be generate out by the CodeGen part. + constexpr ck_tile::index_t kOutputRank = 2; + + constexpr int kBlockPerCu = 1; + + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + // Whether doing the CShuffle (transpose before the global memory), depending on the output + // layout. + constexpr bool CShuffleEpilogue = + std::is_same_v; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTilePartitioner; + + using GemmEpilogue = std::conditional_t< + CShuffleEpilogue, + ck_tile::CShuffleEpilogue>, + ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>>; + + using CodegenGemmTraits = + ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; + using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; + using CodegenGemmPipeline = + ck_tile::GemmPipelineAGmemBGmemCRegV1; + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
+ using Kernel = ck_tile::GemmKernel; + + auto kargs = Kernel::MakeKargs(args.p_a, + args.p_b, + args.p_c, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + args.stride_C); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp new file mode 100644 index 0000000000000000000000000000000000000000..23e99bc2a88c7584b6ce9be2e59357cb801bf068 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -0,0 +1,92 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +template +struct GemmBasicTypeConfig; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +using Types = GemmBasicTypeConfig; + +// Specific type aliases for easy access +using ADataType = Types::ADataType; +using BDataType = Types::BDataType; +using AccDataType = Types::AccDataType; +using CDataType = Types::CDataType; + +struct gemm_basic_args +{ + const void* p_a; + const void* p_b; + void* p_c; + ck_tile::index_t kbatch; + ck_tile::index_t M; + ck_tile::index_t N; + ck_tile::index_t K; + ck_tile::index_t stride_A; + ck_tile::index_t stride_B; + ck_tile::index_t stride_C; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("b", "1", "batch size") + .insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. 
fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ff9d8bad3275bfb1f8219414a8c522fe7b3cdeca --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "gemm_basic.hpp" + +template +float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) +{ + // ToDo: This will be modified by the codegen code later. + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = true; + constexpr bool kPadN = true; + constexpr bool kPadK = true; + + constexpr int kBlockPerCu = 1; + + // =============================================== + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTilePartitioner; + + using GemmEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>; + + using Traits = ck_tile::TileGemmTraits; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< + ck_tile::GemmPipelineProblem>; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< + ck_tile::UniversalGemmPipelineProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKargs(args.p_a, + args.p_b, + args.p_c, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + args.stride_C); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) + { + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + 
Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..8db131738b72366664eefba58e0a7092b6c44f30 --- /dev/null +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once + +template +float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + gemm_basic_args args; + args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); + args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); + args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); + args.kbatch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + + float ave_time = gemm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Gemm{MemBoundPipeline}"}; + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t batch_size = arg_parser.get_int("b"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, a_layout); + stride_B = f_get_default_stride(K, N, stride_B, b_layout); + stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, a_layout)); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, b_layout)); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + + // TODO: add different init types + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_gemm(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_size, + n_warmup, + n_repeat); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); + + std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + ck_tile::HostTensor c_m_n_gpu_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ck_tile::reference_gemm_gpu( + a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); + + std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} + +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/example/ck_tile/03_gemm/script/run_full_test.sh b/example/ck_tile/03_gemm/script/run_full_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..2e2e7fdf90524bfdaedb240b4bc13c6556be9225 --- /dev/null +++ b/example/ck_tile/03_gemm/script/run_full_test.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the tile_example_gemm executables in ../build/bin/ +# +# run the script as "./run_full_test.sh +# input arguments: +# environment tag : a string describing the specifics of your test environment +# branch name : name of the branch in git repo (git status | grep -e 'On branch') +# host name : $hostname +# gpu architecture: e.g., gfx90a, or gfx942, etc. 
+ +# get the command line arguments: +export env_type=$1 +echo 'Environment type: ' $env_type +export branch=$2 +echo 'Branch name: ' $branch +export host_name=$3 +echo 'Host name: ' $host_name +export GPU_arch=$4 +echo 'GPU_arch: ' $GPU_arch + +# run verification tests +example/ck_tile/03_gemm/script/smoke_test.sh + +# We do not have a performance benchmark for gemm yet. Will add it in the future. \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/smoke_test.sh b/example/ck_tile/03_gemm/script/smoke_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..4d9a64bf40dc38f19dd7612ce1dfb9cd14e5ab45 --- /dev/null +++ b/example/ck_tile/03_gemm/script/smoke_test.sh @@ -0,0 +1,35 @@ +#!/bin/bash +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=2 -warmup=0 -repeat=1' + +run_fp16_tests() { + for batch in 1 2; do + for m in 128 1024; do + for n in 128 2048; do + for k in 32 64; do + + $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi + + done + done + done + done +} + +set -x + +run_fp16_tests + +set +x \ No newline at end of file diff --git a/example/ck_tile/04_img2col/CMakeLists.txt b/example/ck_tile/04_img2col/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3864c9ed9d32afd9f88303420493f51197ee1fa7 --- /dev/null +++ b/example/ck_tile/04_img2col/CMakeLists.txt @@ -0,0 +1,3 @@ +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp) diff --git a/example/ck_tile/04_img2col/README.md b/example/ck_tile/04_img2col/README.md new file mode 100644 index 0000000000000000000000000000000000000000..df5c51a9c04b608c8548529d634bbcff19617289 --- /dev/null +++ b/example/ck_tile/04_img2col/README.md @@ -0,0 +1,13 @@ +# Image to Column + +This folder contains example for Image to Column using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +make tile_example_img2col -j +``` +This will result in an executable `build/bin/tile_example_img2col` diff --git a/example/ck_tile/04_img2col/image_to_column.cpp b/example/ck_tile/04_img2col/image_to_column.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6380cd299475c2d9913c490fe532581c534d5b23 --- /dev/null +++ b/example/ck_tile/04_img2col/image_to_column.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include "ck_tile/host.hpp" +#include "image_to_column.hpp" + +// Host API implementation +template <> +float image_to_column(const image_to_column_traits& traits, + const image_to_column_args<2>& args, + const ck_tile::stream_config& stream_conf) +{ + if(traits.data_type.compare("fp16") == 0) + { + constexpr ck_tile::index_t NDimSpatial = 2; + constexpr ck_tile::index_t VectorSize = 8; + + using thread_tile = ck_tile::sequence<8, 8>; + using warp_tile = ck_tile::sequence<64, 64>; + using block_tile = ck_tile::sequence<128, 128>; + + using Shape = ck_tile::TileImageToColumnShape; + + using InDataType = ck_tile::half_t; + using OutDataType = ck_tile::half_t; + + using PipelineProblem = ck_tile::BlockImageToColumnProblem; + + using Kernel = ck_tile::ImageToColumn; + + auto kargs = Kernel::MakeKargs(args.p_in, + args.p_out, + args.G, + args.N, + args.C, + args.input_spatial_lengths, + args.filter_spatial_lengths, + args.output_spatial_lengths, + args.image_g_n_c_wis_strides, + args.gemm_g_m_k_strides, + args.conv_filter_strides, + args.conv_filter_dilations, + args.input_left_pads, + args.input_right_pads); + + const dim3 grids = Kernel::GridSize( + args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1], + args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C, + args.G); + constexpr dim3 blocks = Kernel::BlockSize(); + + constexpr ck_tile::index_t kBlockPerCu = 2; + + float ave_time = ck_tile::launch_kernel( + stream_conf, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + + return 0; +} + +int main(int argc, char* argv[]) +{ + constexpr ck_tile::index_t NDimSpatial = 2; + + ExecutionConfig config; + ck_tile::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatial dimensions" << std::endl; + return EXIT_FAILURE; + } + + using InDataType = ck_tile::half_t; + using OutDataType = ck_tile::half_t; + using ImLayout = ck_tile::tensor_layout::convolution::NHWGC; + + const auto G = conv_params.G_; + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck_tile::long_index_t NHoWo = + N * std::accumulate(conv_params.output_spatial_lengths_.begin(), + std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial), + 1, + std::multiplies<>()); + + const ck_tile::long_index_t CYX = + C * std::accumulate(conv_params.filter_spatial_lengths_.begin(), + std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial), + 1, + std::multiplies<>()); + + const auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + const auto out_desc = ck_tile::HostTensorDescriptor({G, NHoWo, CYX}); + + // host verify + ck_tile::HostTensor in(in_desc); + ck_tile::HostTensor out_device(out_desc); + ck_tile::HostTensor out_host(out_desc); + + switch(config.init_method) + { + case 0: break; + case 1: ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(in); break; + default: ck_tile::FillUniformDistribution{-0.5, 0.5}(in); break; + } + + ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes()); + + in_device_buf.ToDevice(in.data()); + + image_to_column_traits traits{"fp16"}; + + image_to_column_args args{ + in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + G, + N, + C, + 
ck_tile::to_array(conv_params.input_spatial_lengths_), + ck_tile::to_array(conv_params.filter_spatial_lengths_), + ck_tile::to_array(conv_params.output_spatial_lengths_), + ck_tile::to_array(in_desc.get_strides()), + ck_tile::to_array(out_desc.get_strides()), + ck_tile::to_array(conv_params.conv_filter_strides_), + ck_tile::to_array(conv_params.conv_filter_dilations_), + ck_tile::to_array(conv_params.input_left_pads_), + ck_tile::to_array(conv_params.input_right_pads_)}; + + float ave_time = + image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel}); + + std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(config.do_verification) + { + // reference + ck_tile::reference_im2col(in, out_host, conv_params); + + out_device_buf.FromDevice(out_device.data()); + pass = ck_tile::check_err(out_device, out_host); + + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + return !pass; +} diff --git a/example/ck_tile/04_img2col/image_to_column.hpp b/example/ck_tile/04_img2col/image_to_column.hpp new file mode 100644 index 0000000000000000000000000000000000000000..90484e08ec91d1bf06fb4d2060d56cc7dd1b6644 --- /dev/null +++ b/example/ck_tile/04_img2col/image_to_column.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/image_to_column.hpp" +#include + +#define DefaultConvParams \ + ck_tile::conv::ConvParam \ + { \ + 2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \ + } + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck_tile::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck_tile::conv::ConvParam& conv_params) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + config = ExecutionConfig{}; + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck_tile::index_t num_dim_spatial = std::stoi(argv[4]); + conv_params = + ck_tile::conv::parse_conv_param(num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + 
return false; + } + + return true; +} + +struct image_to_column_traits +{ + std::string data_type; +}; + +template +struct image_to_column_args +{ + const void* p_in; + void* p_out; + const ck_tile::long_index_t G; + const ck_tile::long_index_t N; + const ck_tile::long_index_t C; + const ck_tile::array input_spatial_lengths; + const ck_tile::array filter_spatial_lengths; + const ck_tile::array output_spatial_lengths; + const ck_tile::array image_g_n_c_wis_strides; + const ck_tile::array gemm_g_m_k_strides; + const ck_tile::array conv_filter_strides; + const ck_tile::array conv_filter_dilations; + const ck_tile::array input_left_pads; + const ck_tile::array input_right_pads; +}; + +// host API +template +float image_to_column(const image_to_column_traits&, + const image_to_column_args&, + const ck_tile::stream_config&); diff --git a/example/ck_tile/05_reduce/CMakeLists.txt b/example/ck_tile/05_reduce/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6caa38d50d054c1ae99e1bc7a389bb3860155e6b --- /dev/null +++ b/example/ck_tile/05_reduce/CMakeLists.txt @@ -0,0 +1,19 @@ +set(EXAMPLE_REDUCE "tile_example_reduce") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_REDUCE}") + +add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp) +target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_REDUCE_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp new file mode 100644 index 0000000000000000000000000000000000000000..005541dc62bc428df3011df44234754ee663f9b2 --- /dev/null +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -0,0 +1,115 @@ +#include "ck_tile/host.hpp" +#include "reduce.hpp" +#include + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using ComputeDataType = float; + using YDataType = DataType; + + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + ck_tile::HostTensor x_host({m, n}); + ck_tile::HostTensor y_host_ref({m}); + ck_tile::HostTensor y_host_dev({m}); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + 
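    // device buffers for the (m, n) input and the length-m output; the kernel reduces along n
    // with ReduceOp::Add (y[i] = sum over j of x[i][j]), and reference_reduce repeats the same
    // row-wise reduction on the CPU for validation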
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using ReduceOp = ck_tile::ReduceOp::Add; + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using Vector = ck_tile::sequence<8, 8>; + + // cross warp-reduce + // using BlockWarps = ck_tile::sequence<2, 2>; + // using BlockTile = ck_tile::sequence<2, 1024>; + // using WarpTile = ck_tile::sequence<1, 512>; + // using Vector = ck_tile::sequence<1, 8>; + + constexpr ck_tile::index_t kBlockSize = 512; + constexpr ck_tile::index_t kBlockPerCu = 1; + ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); + std::cout << "grid size " << kGridSize << std::endl; + + using Shape = ck_tile::Reduce2dShape; + using Porblem = + ck_tile::Reduce2dProblem; + + using Kernel = ck_tile::Reduce; + + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n)); + + std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_reduce( + x_host, y_host_ref, ReduceOp{}); + y_buf.FromDevice(y_host_dev.mData.data()); + pass = ck_tile::check_err(y_host_dev, y_host_ref); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + // else if(data_type == "bf16") + // { + // return run(arg_parser) ? 0 : -2; + // } +} diff --git a/example/ck_tile/05_reduce/reduce.hpp b/example/ck_tile/05_reduce/reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55e479591c7d34ebe331a01ffaf5423b3916b6d7 --- /dev/null +++ b/example/ck_tile/05_reduce/reduce.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
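// reduce.hpp: defines the block/warp/vector tile shape (Reduce2dShape), the problem
// descriptor (Reduce2dProblem) and the block-wise row-reduction functor (Reduce) used by reduce.cpp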
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp" + +namespace ck_tile { + +template + typename BlockTile, // block size, seq + typename WarpTile, // warp size, seq + typename Vector> // contiguous pixels(vector size) along seq +struct Reduce2dShape +{ + static constexpr index_t Block_M = BlockTile::at(number<0>{}); + static constexpr index_t Block_N = BlockTile::at(number<1>{}); + + static constexpr index_t Warp_M = WarpTile::at(number<0>{}); + static constexpr index_t Warp_N = WarpTile::at(number<1>{}); + + static constexpr index_t Vector_M = Vector::at(number<0>{}); + static constexpr index_t Vector_N = Vector::at(number<1>{}); + + static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{}); + static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{}); + + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + static constexpr index_t BlockSize = + warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); +}; + +template +struct Reduce2dProblem +{ + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + using ReduceOp = ReduceOp_; + + static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; + static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; +}; + +template +struct Reduce +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + +#if 0 + CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) + const + { + using S = typename Problem::BlockShape; + + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto y_m = make_naive_tensor_view_packed( + p_y, make_tuple(M), number<1>{}); + + const auto iM = get_block_id() * S::Block_M; + + auto x_window = make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); + + auto y_window = make_tile_window(y_m, make_tuple(number{}), {iM}); + + const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; }; + + const XDataType reduce_init_value = 0; + + constexpr auto reduce_dims = sequence<1>{}; + + auto y_compute = decltype(block_tile_reduce( + load_tile(x_window), reduce_dims, f_reduce, reduce_init_value)){}; + + set_tile(y_compute, reduce_init_value); + + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + block_tile_reduce(y_compute, x, reduce_dims, f_reduce); + move_tile_window(x_window, {0, S::Block_N}); + } + + block_tile_reduce_sync(y_compute, f_reduce); + + store_tile(y_window, cast_tile(y_compute)); + } +#else + CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + const auto x_m_n = 
make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto y_m = make_naive_tensor_view_packed( + p_y, make_tuple(M), number<1>{}); + + const auto iM = get_block_id() * S::Block_M; + + auto x_window = make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); + + auto y_window = make_tile_window(y_m, make_tuple(number{}), {iM}); + + __shared__ char smem[Policy::template GetSmemSize()]; + + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + + auto reduce_func = typename Problem::ReduceOp{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + using XTensorType = decltype(load_tile(x_window)); + auto y_compute = block_reduce2d.template MakeYBlockTile(); + set_tile(y_compute, reduce_func.template GetIdentityValue()); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + block_reduce2d(x, y_compute, reduce_func); + move_tile_window(x_window, {0, S::Block_N}); + } + + block_reduce2d_sync(y_compute, reduce_func); + block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func); + + store_tile(y_window, cast_tile(y_compute)); + } +#endif +}; + +} // namespace ck_tile diff --git a/example/ck_tile/06_permute/CMakeLists.txt b/example/ck_tile/06_permute/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..327fceb685ed77e505ab10e547bdfde18305533d --- /dev/null +++ b/example/ck_tile/06_permute/CMakeLists.txt @@ -0,0 +1,13 @@ +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp) + +if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL) +# set(PERMUTE_USE_ALTERNATIVE_IMPL false) +set(PERMUTE_USE_ALTERNATIVE_IMPL true) +endif() +if(PERMUTE_USE_ALTERNATIVE_IMPL) +target_compile_options(tile_example_permute PRIVATE -DPERMUTE_USE_ALTERNATIVE_IMPL) +target_sources(tile_example_permute PRIVATE alternative_impl/matrix_core_swizzle.cpp) +endif() +# target_compile_options(tile_example_permute PRIVATE -v --save-temps -Wno-gnu-line-marker) diff --git a/example/ck_tile/06_permute/README.md b/example/ck_tile/06_permute/README.md new file mode 100644 index 0000000000000000000000000000000000000000..03bd810ff4ecd3206cf9e76932708b787360a5cf --- /dev/null +++ b/example/ck_tile/06_permute/README.md @@ -0,0 +1,46 @@ +# permute + +This folder contains example for permute kernel, which is similiar to [torch.permute](https://pytorch.org/docs/stable/generated/torch.permute.html) (combined with [torch.contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html)). Currently we implement a generic permute kernel that support up to rank 8 arbitrary permutation with a single kernel instance. Performance is not the first consideration, we prefer a simple and general kernel implementation using `ck_tile` in this example. + + +``` +args: + -v weather do CPU validation or not (default:1) + -prec data type. 
fp16/bf16/fp32 (default:fp16) + -shape the shape of the input tensor (default:2,3,4) + -perm permute perm (default:2,1,0) +``` + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_permute -j +``` +This will result in an executable `build/bin/tile_example_permute` + + +## some examples +``` +# torch +x=torch.randn(2,3,4,6) +y=x.permute(0,3,2,1).contiguous() + +# ck_tile +./build/bin/tile_example_permute -shape=2,3,4,6 -perm=0,3,2,1 +``` + +or you can try the smoke_test +``` +# in the root of ck_tile, after you build this example +sh example/ck_tile/06_permute/script/smoke_test.sh +``` + +### alternative implementation +we have an alternative implementation under `alternative_impl/` folder, that can swizzle the tensor to be more friendly for data loading for matrix core layout. This can be enabled when dealing with a `rank-7` tensor, with a fixed pattern of either `0,1,4,2,5,3,6` or `0,1,2,4,5,3,6`. There are other shape limitation of this implementation, check the source code of `permute.cpp` for detail. +``` +# example +./build/bin/tile_example_permute -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 # b_n0_k0_n1_k1_n2_k2 +./build/bin/tile_example_permute -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 # b_n0_n1_k0_k1_n2_k2 +``` diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93c662a288fa804d72ad703152d8d83920ec920a --- /dev/null +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp @@ -0,0 +1,98 @@ +#include "matrix_core_swizzle.hpp" +#include "matrix_core_swizzle_kernel.hpp" + +float matrix_core_swizzle(matrix_core_swizzle_traits t, + matrix_core_swizzle_args a, + const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + if(t.inst.compare("32x32x8") == 0) + { + constexpr int BLOCK_SIZE = 256; + constexpr int NPerBlock = 256; + constexpr int KPerBlock = 128; + constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16; + if(t.permute.compare("0,1,4,2,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,2,4,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,3,4,2,5") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + } + else if(t.inst.compare("16x16x16") == 0) + { + constexpr int BLOCK_SIZE = 256; + constexpr int NPerBlock = 256; + constexpr int KPerBlock = 128; + constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16; + if(t.permute.compare("0,1,4,2,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + 
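                // the kernel object derives its launch grid from (batch, n, k) in its constructor;
                // launch_kernel then runs and times it using the options carried by the stream_config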
float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,2,4,5,3,6") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + else if(t.permute.compare("0,1,3,4,2,5") == 0) + { + constexpr matrix_core_permute_style pstyle = + matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; + using Kernel = + matrix_core_swizzle_kernel; + + auto k = Kernel(a); + float ave_time = ck_tile::launch_kernel(s, k); + + return ave_time; + } + } + } + return -1; +} diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e1ecdbbe64120e0351817608ed61cde6f1fdd6fc --- /dev/null +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "matrix_core_swizzle_kernel.hpp" +#include + +struct matrix_core_swizzle_traits +{ + std::string data_type; // fp16 only + std::string inst; // 32x32x8, 16x16x16 + std::string permute; // +}; + +using matrix_core_swizzle_args = matrix_core_swizzle_host_args; + +// host API +float matrix_core_swizzle(matrix_core_swizzle_traits, + matrix_core_swizzle_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp new file mode 100644 index 0000000000000000000000000000000000000000..60ac103ec3d1f23a535813bdebd0048f20e854d1 --- /dev/null +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -0,0 +1,413 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
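// matrix_core_swizzle_kernel.hpp: instruction and permute-style enums plus the templated block
// copy kernel that loads a tile with the source distribution and stores it with the swizzled,
// matrix-core-friendly destination distribution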
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +// if set to 1, slightly more instructions generated to calculate address +#ifndef MERGE_2D_013425 +#define MERGE_2D_013425 0 +#endif + +enum class matrix_core_inst_enum +{ + MFMA_32x32x8_F16 = 0, + MFMA_16x16x16_F16 = 1, +}; + +namespace detail { +template +struct to_warp_gemm; + +template <> +struct to_warp_gemm +{ + using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8; +}; + +template <> +struct to_warp_gemm +{ + using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16; +}; +} // namespace detail +template +using to_warp_gemm_t = typename detail::to_warp_gemm::type; + +// TODO: in below permute pattern, the last 3 dim is within wave +enum class matrix_core_permute_style +{ + permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6 + permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6 + permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, +}; + +// assume this is B matrix, originally we have batch*n*k +// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 +// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4) +// +// 4(waves) 32(mfma_m lane) +// | | +// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading) +// nr kr | +// nr 4 32 kr 2 8 2(klane) +// +// permute: 0,1,4,2,5,3,6 +// or +// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading) +// permute: 0,1,2,4,5,3,6 +// +// this kernel only deal with fp16/bf16 data(16bit), and use 2d block size to do the swizzling +// for simplicity, only consider n/k is multiple of block-size + +// independend host arg with no template +struct matrix_core_swizzle_host_args +{ + const void* p_src; + void* p_dst; + int32_t batch; + int32_t n; + int32_t k; +}; + +// NOTE: this kernel could follow the style of generic permute kernel +// but here we pass in fixed layout as template arg and generate different kernel instance +// purposely +template +struct matrix_core_swizzle_kernel +{ + using karg = matrix_core_swizzle_host_args; + using harg = matrix_core_swizzle_host_args; + + static constexpr int BLOCK_SIZE = BLOCK_SIZE_; + static constexpr int WavesPerBlock_N = 4; + static constexpr int WavesPerBlock_K = 1; + static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE); + static constexpr int NPerBlock = NPerBlock_; + static constexpr int KPerBlock = KPerBlock_; + static constexpr matrix_core_permute_style pstyle = pstyle_; + static constexpr matrix_core_inst_enum Inst = Inst_; + + static constexpr ck_tile::index_t Alignment = 8; + karg a; + dim3 grids; + + using WarpGemm = to_warp_gemm_t; + + __host__ matrix_core_swizzle_kernel(harg h) + { + a = h; + ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock; + ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock; + grids = dim3(ks, ns, h.batch); + } + + __host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; } + + __host__ void operator()(const ck_tile::stream_config& s) const + { + ck_tile::kentry<<>>(a); + } + + struct kernel + { + __device__ static constexpr auto get_src_dist() + { + using namespace ck_tile; + constexpr index_t K2 = Alignment; + constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t N1 = BLOCK_SIZE / get_warp_size(); + + static_assert(NPerBlock % (N1 * N2) == 0); + static_assert(KPerBlock % (K1 * K2) == 0); + + constexpr 
index_t K0 = KPerBlock / (K1 * K2); + constexpr index_t N0 = NPerBlock / (N1 * N2); + + // clang-format off + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>,// 0 + // 1 2 3 4 5 6 + tuple, sequence, sequence, sequence, sequence, sequence>, + + // N1 K1 N2 + tuple, sequence<5, 3>>, + tuple, sequence<0, 0>>, + + // N0 K0 K2 + sequence<1, 4, 6>, + sequence<0, 0, 0>>{}); + // clang-format on + } + __device__ static constexpr auto get_dst_dist() + { + using namespace ck_tile; + constexpr index_t K2 = Alignment; + constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t N1 = BLOCK_SIZE / get_warp_size(); + + static_assert(NPerBlock % (N1 * N2) == 0); + static_assert(KPerBlock % (K1 * K2) == 0); + + constexpr index_t K0 = KPerBlock / (K1 * K2); + constexpr index_t N0 = NPerBlock / (N1 * N2); + + if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2) + { + // clang-format off + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>,// 0 + // 1 2 3 4 5 6 + tuple, sequence, sequence, sequence, sequence, sequence>, + + // N1 K1 N2 + tuple, sequence<4, 5>>, + tuple, sequence<0, 0>>, + + // N0 K0 K2 + sequence<1, 2, 6>, + sequence<0, 0, 0>>{}); + // clang-format on + } + else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2) + { + // clang-format off + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>,// 0 + // 1 2 3 4 5 6 + tuple, sequence, sequence, sequence, sequence, sequence>, + + // N1 K1 N2 + tuple, sequence<4, 5>>, + tuple, sequence<0, 0>>, + + // N0 K0 K2 + sequence<1, 3, 6>, + sequence<0, 0, 0>>{}); + // clang-format on + } + else + { + // clang-format off + // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten + constexpr index_t Kv = Alignment; + constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + + static_assert(KPerBlock % (K1 * K2) == 0); + constexpr index_t Nr = NPerBlock / Nw; + constexpr index_t Kr = KPerBlock / (Kv * Kw); + + constexpr index_t Nr_p = WavesPerBlock_N; + constexpr index_t Kr_p = WavesPerBlock_K; + constexpr index_t Nr_y = Nr / Nr_p; + constexpr index_t Kr_y = Kr / Kr_p; + + return make_static_tile_distribution( +#if MERGE_2D_013425 + tile_distribution_encoding< + sequence<1>,// 0 R + // major 1 2 + // minor 0 1 2 0 1 2 3 + tuple, sequence>, // H + + // Nr_p, Kr_p Kw Nw + tuple, sequence<2, 1>>, // p major + tuple, sequence<2, 2>>, // p minor + + // Nr_y Kr_y Kv + sequence<1, 2, 2>, // Y major + sequence<0, 0, 3>>{}); // y minor +#else + tile_distribution_encoding< + sequence<1>,// 0 R + // major 1 2 3 + // minor 0 1 0 1 0 1 2 + tuple, sequence, sequence>, // H + + // Nr_p, Kr_p Kw Nw + tuple, sequence<3, 3>>, // p major + tuple, sequence<0, 1>>, // p minor + + // Nr_y Kr_y Kv + sequence<1, 2, 3>, // Y major + sequence<0, 0, 2>>{}); // y minor +#endif + // clang-format on + } + } + + __device__ void operator()(karg a_) + { + using namespace ck_tile; + index_t i_k = blockIdx.x; + index_t i_n = blockIdx.y; + index_t i_b = blockIdx.z; + + constexpr index_t k2 = Alignment; + constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t n1 = BLOCK_SIZE / get_warp_size(); + const index_t k0 = a_.k / (k1 * k2); + const index_t n0 = a_.n / (n1 * n2); + 
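            // illustrative numbers, taken from the README example (-shape=3,6,4,32,16,2,8,
            // -perm=0,1,4,2,5,3,6): with MFMA_32x32x8_F16 and BLOCK_SIZE=256 we have n2=32,
            // n1=4, k1=2, k2=8, so n = 6*4*32 = 768 and k = 16*2*8 = 256 decompose into
            // n0 = 768/(4*32) = 6 and k0 = 256/(2*8) = 16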
+ constexpr index_t k2_tile = Alignment; + constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size(); + constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile); + constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile); + + const fp16_t* p_src = reinterpret_cast(a_.p_src) + i_b * a_.k * a_.n; + fp16_t* p_dst = reinterpret_cast(a_.p_dst) + i_b * a_.k * a_.n; + + const auto src_view = [&]() { + const auto tmp = make_naive_tensor_view_packed( + p_src, + make_tuple(n0, n1, n2, k0, k1, k2), + number{}); // control vector load + return tmp; + }(); + + const auto src_window = make_tile_window(src_view, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + {i_n * n0_tile, 0, 0, i_k * k0_tile, 0, 0}, + get_src_dist()); + + auto dst_view = [&]() { + if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2) + { + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(n0, k0, n1, k1, n2, k2), + number{}); // control vector load + return tmp; + } + else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2) + { + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(n0, n1, k0, k1, n2, k2), + number{}); // control vector load + return tmp; + } + else + { +#if MERGE_2D_013425 + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + // constexpr index_t waveflatten = kw*nw*kv; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(nr, kr, number{}, number{}, number{}), + number{}); // control vector load + auto tmp_1 = transform_tensor_view( + tmp, + make_tuple( + make_merge_transform(make_tuple(nr, number{})), + make_merge_transform(make_tuple(kr, number{}, number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return tmp_1; +#else + // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t waveflatten = kw * nw * kv; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; + auto tmp = make_naive_tensor_view_packed( + p_dst, + make_tuple(nr, kr, waveflatten), + number{}); // control vector load + return tmp; +#endif + } + }(); + + auto dst_window = [&]() { + if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2) + { + return make_tile_window(dst_view, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + {i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0}, + get_dst_dist()); + } + else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2) + { + return make_tile_window(dst_view, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + {i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0}, + get_dst_dist()); + } + else + { +#if MERGE_2D_013425 + // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv + return make_tile_window(dst_view, + make_tuple(number{}, number{}), + {i_n * NPerBlock, i_k * KPerBlock}, + get_dst_dist()); +#else + // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv + constexpr 
index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t waveflatten_tile = kw * nw * kv; + constexpr index_t nr_tile = NPerBlock / nw; + constexpr index_t kr_tile = KPerBlock / (kw * kv); + return make_tile_window(dst_view, + make_tuple(number{}, + number{}, + number{}), + {i_n * nr_tile, i_k * kr_tile, 0}, + get_dst_dist()); +#endif + } + }(); + + // actual load store + auto src_tile = load_tile(src_window); + + // now we only swap the distribution from src to dst, no extra movement occurs + auto dst_tile = make_static_distributed_tensor(get_dst_dist()); + dst_tile.get_thread_buffer() = src_tile.get_thread_buffer(); + + // final store + store_tile(dst_window, dst_tile); + } + }; +}; diff --git a/example/ck_tile/06_permute/permute.cpp b/example/ck_tile/06_permute/permute.cpp new file mode 100644 index 0000000000000000000000000000000000000000..af95b64e69e23f55cbed079f0733743d4b160993 --- /dev/null +++ b/example/ck_tile/06_permute/permute.cpp @@ -0,0 +1,411 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "permute.hpp" +#include "ck_tile/host.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef PERMUTE_USE_ALTERNATIVE_IMPL +#include "alternative_impl/matrix_core_swizzle.hpp" +#endif + +namespace detail { +template +struct to_integer_type; + +template <> +struct to_integer_type<4> +{ + using type = int32_t; +}; +template <> +struct to_integer_type<2> +{ + using type = int16_t; +}; +template <> +struct to_integer_type<1> +{ + using type = int8_t; +}; +} // namespace detail + +template +using to_integer_type = typename detail::to_integer_type::type; + +// host API (shoule come from codegen) +float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp8") == 0) + { + using DataType = ck_tile::fp8_t; + using PipelineProblem = ck_tile::GenericPermuteProblem; + using Kernel = ck_tile::GenericPermute; + + auto kargs = Kernel::MakeKargs(a); + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + else if(t.data_type.compare("fp16") == 0) + { + using DataType = ck_tile::half_t; + using PipelineProblem = ck_tile::GenericPermuteProblem; + using Kernel = ck_tile::GenericPermute; + + auto kargs = Kernel::MakeKargs(a); + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + else if(t.data_type.compare("fp32") == 0) + { + using DataType = float; + using PipelineProblem = ck_tile::GenericPermuteProblem; + using Kernel = ck_tile::GenericPermute; + + auto kargs = Kernel::MakeKargs(a); + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + + return 0; +} + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } 
+ return os << "]"; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)") + .insert("shape", "2,3,4", "the shape of the input tensor") + .insert("perm", "2,1,0", "permute perm") + .insert("kname", "0", "t to 1 will print kernel name") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +// "1,2,3,4" -> vector{1,2,3,4} +std::vector decode_vec(std::string q_val) +{ +#define _S2I_(str_) static_cast(std::atoi((str_).c_str())) + std::string::size_type pos = 0; + std::vector v; + while(true) + { + auto found = q_val.find(',', pos); + ck_tile::index_t n = + _S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos)); + v.push_back(n); + if(found == std::string::npos) + { + break; + } + pos = found + 1; + } + return v; +#undef _S2I_ +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + + auto shape = decode_vec(arg_parser.get_str("shape")); + auto perm = decode_vec(arg_parser.get_str("perm")); + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + int seed = arg_parser.get_int("seed"); + + assert(shape.size() == perm.size()); + ck_tile::index_t rank = perm.size(); + if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks) + { + printf("rank %d permute is not support yet\n", rank); + return false; + } + + ck_tile::HostTensor x(shape); + ck_tile::FillUniformDistributionIntegerValue{-15, 15, seed}(x); + + std::vector y_shape = [&]() { + std::vector tmp(rank, 0); + // std::cout << "@@@@" << tmp << std::endl; + for(int i = 0; i < static_cast(rank); i++) + { + // std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" << + // static_cast(rank) + // << std::endl; + tmp[i] = shape[perm[i]]; + } + // std::cout << "@@@" << tmp << std::endl; + return tmp; + }(); + + ck_tile::HostTensor y(y_shape); + + ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x.data()); + + std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm + << std::flush; + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 
1 : 0), + stream_warmup, + stream_repeat}; + float ave_time = 0.f; + auto run_permute = [&]() { + permute_traits t; + t.data_type = data_type; + + permute_args a; + a.p_src = x_buf.GetDeviceBuffer(); + a.p_dst = y_buf.GetDeviceBuffer(); + a.rank = rank; + std::copy(shape.begin(), shape.end(), a.shape); + std::copy(perm.begin(), perm.end(), a.perm); + + return permute(t, a, stream_config); + }; +#ifdef PERMUTE_USE_ALTERNATIVE_IMPL + // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 + if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") || + arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") || + arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))) + { + if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")) + { + // permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + matrix_core_swizzle_traits t; + t.data_type = data_type; + t.permute = arg_parser.get_str("perm"); + + matrix_core_swizzle_args a; + a.p_src = x_buf.GetDeviceBuffer(); + a.p_dst = y_buf.GetDeviceBuffer(); + a.batch = shape[0]; + + auto nr = shape[1]; + auto nw = shape[2]; + auto kr = shape[3]; + auto kw = shape[4]; + auto kv = shape[5]; + a.n = nr * nw; + a.k = kr * kw * kv; + if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0) + { + t.inst = "16x16x16"; + std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0) + { + t.inst = "32x32x8"; + std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else + { + ave_time = run_permute(); + } + } + else + { + matrix_core_swizzle_traits t; + t.data_type = data_type; + t.permute = arg_parser.get_str("perm"); + + matrix_core_swizzle_args a; + a.p_src = x_buf.GetDeviceBuffer(); + a.p_dst = y_buf.GetDeviceBuffer(); + a.batch = shape[0]; + a.n = shape[1] * shape[2] * shape[3]; + a.k = shape[4] * shape[5] * shape[6]; + if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 && + shape[4] % 8 == 0 && shape[1] % 2 == 0) + { + // 32x32x8 inst + // perm=0,1,4,2,5,3,6 + // y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8) + // shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8) + + t.inst = "32x32x8"; + std::cout << ", matrix_core_swizzle_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 && + shape[4] % 4 == 0 && shape[1] % 4 == 0) + { + // 16x16x16 inst + // perm=0,1,4,2,5,3,6 + // y_shape=*,4x,4x,4,4,16,8 + // shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8) + t.inst = "16x16x16"; + std::cout << ", matrix_core_swizzle_" << t.inst << std::flush; + + ave_time = matrix_core_swizzle(t, a, stream_config); + } + else + { + ave_time = run_permute(); + } + } + } + else +#endif + { + ave_time = run_permute(); + } + std::cout << ", time:" << ave_time << "ms" << std::flush; + + bool pass = true; + if(do_validation) + { + reference_permute(x, y, perm); +#if 0 + if constexpr (std::is_same_v){ + // using itype = to_integer_type; + fflush(stdout); + for(int zz = 0; zz < static_cast(x.get_element_size()); zz++ ) { + printf("%3.0f ", x.mData[zz]); + } + printf("->\n"); + for(int zz = 0; zz < static_cast(x.get_element_size()); zz++ ) { + printf("%3.0f ", y.mData[zz]); + } + fflush(stdout); + } +#endif + ck_tile::HostTensor y_dev(y.get_lengths()); + + y_buf.FromDevice(y_dev.data()); + + pass = std::equal( + y_dev.begin(), y_dev.end(), 
y.begin(), [&](const DataType& d, const DataType& h) { + using itype = to_integer_type; + itype i_d = ck_tile::bit_cast(d); + itype i_h = ck_tile::bit_cast(h); + return i_d == i_h; + }); + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; + } + + std::cout << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp32") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/06_permute/permute.hpp b/example/ck_tile/06_permute/permute.hpp new file mode 100644 index 0000000000000000000000000000000000000000..304da4dc9781097061dabf06b5f02dbab8bfe9c4 --- /dev/null +++ b/example/ck_tile/06_permute/permute.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/permute.hpp" +#include + +struct permute_traits +{ + std::string data_type; +}; + +using permute_args = ck_tile::GenericPermuteHostArgs; + +// host API +float permute(permute_traits, permute_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/06_permute/script/smoke_test.sh b/example/ck_tile/06_permute/script/smoke_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..793e52d2bb7e1ed004c464e261355937f344dba5 --- /dev/null +++ b/example/ck_tile/06_permute/script/smoke_test.sh @@ -0,0 +1,34 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_permute +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 +if [ $# -ge 1 ] ; then + set -x +fi + +$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS +$EXE -prec=fp16 -shape=2,8,16,8,4,8 -perm=0,1,3,4,2,5 $COMMON_ARGS +$EXE -prec=fp16 -shape=1,24,32,16,2,8 -perm=0,1,3,4,2,5 $COMMON_ARGS + +echo "------------------------------------------------------------------" + +for prec in "fp8" "fp16" "fp32" ; do + +$EXE -prec=$prec -shape=3,8 -perm=1,0 $COMMON_ARGS +$EXE -prec=$prec -shape=48,6,8 -perm=2,1,0 $COMMON_ARGS +$EXE -prec=$prec -shape=24,128,3 -perm=0,2,1 $COMMON_ARGS +$EXE -prec=$prec -shape=4,10,7,6 -perm=0,2,3,1 $COMMON_ARGS +$EXE -prec=$prec -shape=8,24,36,10 -perm=3,1,2,0 $COMMON_ARGS +$EXE -prec=$prec -shape=8,1,36,4 -perm=2,1,0,3 $COMMON_ARGS +$EXE -prec=$prec -shape=5,10,16,2,36,4 -perm=4,5,2,1,0,3 $COMMON_ARGS +$EXE -prec=$prec -shape=2,32,8,3,6,2,5,4 -perm=5,2,4,7,1,6,3,0 $COMMON_ARGS +echo "------------------------------------------------------------------" +done diff --git a/example/ck_tile/09_topk_softmax/CMakeLists.txt b/example/ck_tile/09_topk_softmax/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b43b9897922f16592732b8fb4083d144216af46b --- /dev/null +++ b/example/ck_tile/09_topk_softmax/CMakeLists.txt @@ -0,0 +1,8 
@@ +add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL topk_softmax.cpp topk_softmax_api.cpp) +target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + +set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS}) diff --git a/example/ck_tile/09_topk_softmax/README.md b/example/ck_tile/09_topk_softmax/README.md new file mode 100644 index 0000000000000000000000000000000000000000..10430129003b66b94e328c6b19a6bc9b431bd535 --- /dev/null +++ b/example/ck_tile/09_topk_softmax/README.md @@ -0,0 +1,28 @@ +# topk-softmax + +This folder contains example for topk-softmax kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input is a `token*expert` 2d matrix. The op will do a softmax per row(`expert`), then find the `topk` value for each row. Output is a `token*topk` weight(usually fp32) and index(int32) 2d tensor. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_topk_softmax -j +``` +This will result in an executable `build/bin/tile_example_topk_softmax` + +## example +``` +args: + -v weather do CPU validation or not (default:1) + -pr_i input data type. fp16/fp32 (representing 8/16/32 bit data) (default:fp16) + -pr_w output weight data type(currently only fp32 supported now) (default:fp32) + -t number of input tokens (default:32) + -e number of experts (default:8) + -k topk (default:2) + -st_i row stride of input, -1 means same as experts (default:-1) + -st_o row stride of output/indices, -1 means same as topk (default:-1) + -seed seed to be used, -1 means random every time (default:-1) + -kname when set to 1 it will print kernel name (default:0) + +``` diff --git a/example/ck_tile/09_topk_softmax/script/smoke_test.sh b/example/ck_tile/09_topk_softmax/script/smoke_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..646f5889f708483d2b083b721ece9875297cce77 --- /dev/null +++ b/example/ck_tile/09_topk_softmax/script/smoke_test.sh @@ -0,0 +1,22 @@ +#!/bin/sh + +EXE=./build/bin/tile_example_topk_softmax + +for pr_i in "fp16" "bf16" ; do +$EXE -pr_i=$pr_i -t=80 -e=17 +$EXE -pr_i=$pr_i -t=111 -e=117 +$EXE -pr_i=$pr_i -t=1000 -e=55 +$EXE -pr_i=$pr_i -t=99 -e=180 +$EXE -pr_i=$pr_i -t=175 -e=64 -k=8 +$EXE -pr_i=$pr_i -t=65 -e=8 -k=2 +$EXE -pr_i=$pr_i -t=1 -e=25 +$EXE -pr_i=$pr_i -t=31 -e=19 -k=15 +$EXE -pr_i=$pr_i -t=81 -e=37 -k=7 +$EXE -pr_i=$pr_i -t=199 -e=128 -k=13 +$EXE -pr_i=$pr_i -t=23 -e=1 -k=1 +$EXE -pr_i=$pr_i -t=127 -e=99 -k=19 -st_i=233 -st_o=31 +$EXE -pr_i=$pr_i -t=71 -e=11 -k=11 -st_i=30 -st_o=12 +$EXE -pr_i=$pr_i -t=1 -e=1 -k=1 +$EXE -pr_i=$pr_i -t=99 -e=2 -k=1 -st_i=11 -st_o=5 +$EXE -pr_i=$pr_i -t=333 -e=99 -k=13 -st_i=191 -st_o=17 +done diff --git a/example/ck_tile/09_topk_softmax/topk_softmax.cpp b/example/ck_tile/09_topk_softmax/topk_softmax.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6fc25631fdef01206de0cc5c4db5a89149990386 --- /dev/null +++ b/example/ck_tile/09_topk_softmax/topk_softmax.cpp @@ -0,0 +1,299 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "topk_softmax_api.hpp" + +#if 0 +template +void dump_host_tensor_2d(const ck_tile::HostTensor& x) +{ + auto len = x.get_lengths(); + assert(len.size() == 2); + std::cout << "["; + for(size_t i = 0; i < len[0]; i++) + { + std::cout << i << ": ["; + for(size_t j = 0; j < len[1]; j++) + { + if constexpr(std::is_same_v) + { + auto v = ck_tile::type_convert(x(i, j)); + + std::cout << v; + if(j != len[1] - 1) + std::cout << ","; + } + else + { + std::cout << x(i, j) << " "; + } + } + std::cout << "]"; + if(i != len[0] - 1) + std::cout << ","; + else + std::cout << "]"; + std::cout << std::endl; + } + std::cout << "--------------------" << std::endl; +} +#endif + +// CPU reference +template +auto reference_topk_softmax(const ck_tile::HostTensor& x, + ck_tile::index_t k, + ck_tile::index_t dim = -1, + bool largest = true, + bool sorted = true) +{ + using namespace ck_tile; + + auto y = reference_softmax(x, dim); + + auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted); + + return ck_tile::make_tuple(y_values, y_indices); +} + +template +auto reference_topk_softmax(const ck_tile::HostTensor& x, + ck_tile::HostTensor& y_values, + ck_tile::HostTensor& y_indices, + ck_tile::index_t k, + ck_tile::index_t dim = -1, + bool largest = true, + bool sorted = true) +{ + using namespace ck_tile; + + auto y = reference_softmax(x, dim); + reference_topk(y, y_values, y_indices, k, dim, largest, sorted); +} + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("pr_i", "fp16", "input data type. 
fp16/fp32 (representing 8/16/32 bit data)") + .insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)") + .insert("t", "32", "number of input tokens") + .insert("e", "8", "number of experts") + .insert("k", "2", "topk") + .insert("st_i", "-1", "row stride of input, -1 means same as experts") + .insert("st_o", "-1", "row stride of output/indices, -1 means same as topk") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "when set to 1 it will print kernel name") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool test_topk_softmax(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string input_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + int tokens = args.get_int("t"); + int experts = args.get_int("e"); + int topk = args.get_int("k"); + int seed = args.get_int("seed"); + int stride_input = args.get_int("st_i"); + int stride_output = args.get_int("st_o"); + int kname = args.get_int("kname"); + int warmup = args.get_int("warmup"); + int repeat = args.get_int("repeat"); + + if(stride_input < 0) + { + stride_input = experts; + } + if(stride_output < 0) + { + stride_output = topk; + } + assert(stride_input >= experts); + assert(stride_output >= topk); + + if(seed < 0) + { + seed = std::time(nullptr); + } + + if(topk > experts) + { + printf("topk:%d value should be smaller than, or equal to number of experts:%d\n", + topk, + experts); + return false; + } + + // tokens already considered batch size + ck_tile::HostTensor x_host({tokens, experts}, {stride_input, 1}); + ck_tile::HostTensor value_host({tokens, topk}, {stride_output, 1}); + ck_tile::HostTensor index_host({tokens, topk}, {stride_output, 1}); + + { + // random require per-row unique + auto rand_gen = ck_tile::FillUniformDistribution_Unique{ + -5.f, 5.f, static_cast(seed)}; + + for(int i_t = 0; i_t < tokens; i_t++) + { + ck_tile::HostTensor x_row({experts}); + rand_gen(x_row); + std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * stride_input); + rand_gen.clear(); + } + } + + ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem value_dev(value_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem index_dev(index_host.get_element_space_size_in_bytes()); + + x_dev.ToDevice(x_host.data()); + + topk_softmax_trait trait{input_prec, weight_prec, experts}; + + topk_softmax_kargs karg{x_dev.GetDeviceBuffer(), + value_dev.GetDeviceBuffer(), + index_dev.GetDeviceBuffer(), + tokens, + experts, + topk, + stride_input, + stride_output}; + + ck_tile::stream_config sc{nullptr, + true, + /* log_level = */ (kname ? 
1 : 0), + warmup, + repeat}; + auto ms = topk_softmax(trait, karg, sc); + printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ", + input_prec.c_str(), + weight_prec.c_str(), + tokens, + experts, + topk, + stride_input, + stride_output, + ms); + if(ms < 0) + printf("not supported\n"); + fflush(stdout); + if(ms < 0) + { + return false; + } + + value_dev.FromDevice(value_host.data()); + index_dev.FromDevice(index_host.data()); + + bool rtn = true; + if(validate) + { + ck_tile::HostTensor value_ref({tokens, topk}, {stride_output, 1}); + ck_tile::HostTensor index_ref({tokens, topk}, {stride_output, 1}); + + reference_topk_softmax( + x_host, value_ref, index_ref, topk); + + auto [rtol, atol] = get_elimit(""); + for(int i_t = 0; i_t < tokens; i_t++) + { + auto s_begin = std::vector{static_cast(i_t), static_cast(0)}; + auto s_end = + std::vector{static_cast(i_t + 1), static_cast(topk)}; + auto s_value_host = value_host.slice(s_begin, s_end); + auto s_value_ref = value_ref.slice(s_begin, s_end); + rtn &= ck_tile::check_err(s_value_host, + s_value_ref, + std::string("[") + std::to_string(i_t) + + std::string("] Value Error:"), + rtol, + atol); + auto s_index_host = index_host.slice(s_begin, s_end); + auto s_index_ref = index_ref.slice(s_begin, s_end); + rtn &= ck_tile::check_err(s_index_host, + s_index_ref, + std::string("[") + std::to_string(i_t) + + std::string("] Index Error:"), + rtol, + atol); + } + } + + printf("valid:%s\n", rtn ? "y" : "n"); + fflush(stdout); + return rtn; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + std::string input_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + + bool r = true; + if(input_prec.compare("fp16") == 0 && weight_prec.compare("fp32") == 0) + { + r &= test_topk_softmax(args); + } + else if(input_prec.compare("bf16") == 0 && weight_prec.compare("fp32") == 0) + { + r &= test_topk_softmax(args); + } + + return r ? 0 : -1; +} diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..249a307b81c57f209accb0567382b54f29f75bdd --- /dev/null +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
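// topk_softmax_api.cpp: runtime dispatch for the example; TOPK_SOFTMAX_DISPATCH instantiates a
// TopkSoftmaxKernel for the requested input/weight types with the expert count rounded up to the
// next supported tile width (8/16/32/64/128/192), and the launcher returns -1 when nothing matches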
+ +#include "topk_softmax_api.hpp" + +#define TOPK_SOFTMAX_DISPATCH(experts_) \ + constexpr ck_tile::index_t ts_experts = experts_; \ + using ts_problem = ck_tile:: \ + TopkSoftmaxWarpPerRowProblem; \ + using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline; \ + \ + using kernel = ck_tile::TopkSoftmaxKernel; \ + \ + auto kargs = kernel::MakeKargs(a); \ + \ + const dim3 grids = kernel::GridSize(a); \ + constexpr dim3 blocks = kernel::BlockSize(); \ + \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); \ + \ + return ave_time; + +float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s) +{ + if(t.input_type == "fp16" && t.weight_type == "fp32") + { + using ts_input_type = ck_tile::fp16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; +#if 1 + if(t.experts <= 8) + { + TOPK_SOFTMAX_DISPATCH(8) + } + else if(t.experts <= 16) + { + TOPK_SOFTMAX_DISPATCH(16) + } + else if(t.experts <= 32) + { + TOPK_SOFTMAX_DISPATCH(32) + } + else if(t.experts <= 64) + { + TOPK_SOFTMAX_DISPATCH(64) + } + else if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128) + } + else if(t.experts <= 192) + { + TOPK_SOFTMAX_DISPATCH(192) + } +#else + if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128) + } +#endif + } + else if(t.input_type == "bf16" && t.weight_type == "fp32") + { +#if 1 + using ts_input_type = ck_tile::bf16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; + if(t.experts <= 8) + { + TOPK_SOFTMAX_DISPATCH(8) + } + else if(t.experts <= 16) + { + TOPK_SOFTMAX_DISPATCH(16) + } + else if(t.experts <= 32) + { + TOPK_SOFTMAX_DISPATCH(32) + } + else if(t.experts <= 64) + { + TOPK_SOFTMAX_DISPATCH(64) + } + else if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128) + } + else if(t.experts <= 192) + { + TOPK_SOFTMAX_DISPATCH(192) + } +#endif + } + return -1; +} diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp new file mode 100644 index 0000000000000000000000000000000000000000..65651efa4d42bb381c9a9c10d963b8ca3c5352d2 --- /dev/null +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
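// topk_softmax_api.hpp: host-side interface for the example; declares the trait and kargs
// structs and the topk_softmax launcher implemented in topk_softmax_api.cpp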
+ +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/topk_softmax.hpp" +#include + +struct topk_softmax_trait +{ + std::string input_type; + std::string weight_type; // currently always float + int experts; +}; + +struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs +{ +}; + +float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s); diff --git a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a3ff8fdf4595715a1620e3a8f46de870ecd06bb1 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt @@ -0,0 +1,25 @@ +set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding ${TILE_RMSNORM2D_FWD}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) +target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${INSTANCE_SRCS}) + +set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) + +set(EXAMPLE_RMSNORM2D_FWD "tile_example_rmsnorm2d_fwd") +add_executable(${EXAMPLE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL example_rmsnorm2d_fwd.cpp) +target_compile_options(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/10_rmsnorm2d/README.md b/example/ck_tile/10_rmsnorm2d/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c067496477167cb89c1048a342a27e0c5d40ac72 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/README.md @@ -0,0 +1,22 @@ +# Rmsnorm2D forward + +This folder contains example for Rmsnorm2D forward using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... 
+make tile_rmsnorm2d_fwd -j +``` +This will result in an executable `build/bin/tile_rmsnorm2d_fwd` + +## cmdline +``` +args: + -m m dimension (default:3328) + -n m dimension (default:4096) + -e epsilon (default:1e-5) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..34df7b74fa3c710becc10670b54c76f910317781 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -0,0 +1,165 @@ +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/rmsnorm2d.hpp" +#include + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using XDataType = DataType; + using YDataType = DataType; + using GammaDataType = DataType; + using InvRmsDataType = ck_tile::null_type; + + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); + + ck_tile::HostTensor invRms_host_ref({m}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + constexpr bool kTwoPass = true; + + using BlockWarps = ck_tile::sequence<2, 2>; + using BlockTile = ck_tile::sequence<2, 128>; + using WarpTile = ck_tile::sequence<1, 64>; + using Vector = ck_tile::sequence<1, 1>; + + using Shape = ck_tile::Generic2dBlockShape; + using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem; + + using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + using Kernel = ck_tile::Rmsnorm2dFwd; + + ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), + nullptr, + epsilon, + m, + n, + stride}; + + auto kargs = Kernel::MakeKargs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + auto s = 
ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_rmsnorm2d_fwd( + x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + + y_buf.FromDevice(y_host_dev.data()); + + auto [rtol, atol] = ck_tile::make_tuple(1e-3, 1e-3); + if(stride == n) + { + pass = ck_tile::check_err( + y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector y_host_dev_row(y_host_dev.begin() + i_r * stride, + y_host_dev.begin() + i_r * stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * stride, + y_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(y_host_dev_row, + y_host_ref_row, + std::string("OUT[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride + << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b8697183f96bc6b4421ecbd8b26353ae7c00941e --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
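// rmsnorm2d_fwd_b16_() below selects a pre-built kernel instance by first bucketing the row
// length a.n (<= 64, <= 128, ..., > 4096) and then, inside each bucket, choosing the widest
// per-thread vector (8/4/2/1 elements) that divides a.n evenly; row lengths that do not
// divide evenly fall back to a narrower and/or padded variant.  The concrete trait_
// parameter sets are explicitly instantiated in the per-N instance files in this directory.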
+ +#include +#include "rmsnorm2d_fwd.hpp" + +template +using trait_ = rmsnorm2d_fwd_traits_; + +template +float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/, + rmsnorm2d_fwd_args a, + const ck_tile::stream_config& s) +{ + float r = -1; + // clang-format off + // rm rn tm tn vn pd rms 2p + if(a.n <= 64) { + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n > 4096) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + return r; + // clang-format on +} + +float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s) +{ + + if(t.data_type.compare("fp16") == 0) + { + return rmsnorm2d_fwd_b16_(t, a, s); + } + else if(t.data_type.compare("bf16") == 0) + { + return rmsnorm2d_fwd_b16_(t, a, s); + } + else + throw std::runtime_error("Without supported instances!"); +} diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e2a35f9e8fb496f21832c222b9f68042a63a21d --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
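// The "rm rn tm tn vn pd rms 2p" legend used across these instance files corresponds to the
// parameters of rmsnorm2d_fwd_traits_ declared in rmsnorm2d_fwd.hpp: Repeat_M, Repeat_N,
// ThreadPerBlock_M, ThreadPerBlock_N, Vector_N, kPadN, kSaveInvRms and kTwoPass (all
// following the data type).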
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +#if 0 +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +template float rmsnorm2d_fwd_>(const S&, A); +#endif + +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8c734806e18b4782092f7a0e5cc460b3abc158d4 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9222001433464eebcf1e20911b6b06b85c117270 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed33c849232cc95d251240f7d146678273bd4e52 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b753bbc3458d3194f0cc6962b51d499bd331848b --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..27cb9bdf3d47dc34909e0c1333c5daa32e640274 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..23afb5672b4b109fa9d2b89abec5318766540e92 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b428f58051bae64bd0497a991a7fc265ad96ec3d --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3001106697dafa0d521672af936e17b1cf2fddac --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e9c8d6a1d444b0b085bbfc20575352a1d766b2a3 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15198eebe67258266529ed81a7c8f5bf16d48ca2 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +#if 0 +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +template float rmsnorm2d_fwd_>(const S&, A); +#endif + +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ac85fa9b5a68246a5af7c039b4131a3b35c9c56 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..10e8fafc2f4c780ff622ee501f685671fc7dd25a --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e1a80bf64b3598864765454fd46f8ce9c9c6eb0 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45e56a92b8886ffe3b07189646aad12caeffb359 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35401f6f82b50c40137599456250c13c46092cb2 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1e3700fad3ab61ff669d2950f96578170a01320e --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cdc4d00bd2336a1e55c900cd078dd8cde52ac11b --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ec80c2ee4a93f999be3960ba7154a16d8992f302 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ddfc5a54e8e6ea3129804f28a56aed98b5432f67 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8f6ff84b643d2b7fafebc5b0a9ef6ade1ebdbd23 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp @@ -0,0 +1,65 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
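// rmsnorm2d_fwd_() in this header is the launcher shared by all instance translation units:
// it builds a Rmsnorm2dFwdPipelineProblem from the trait, selects the one-pass or two-pass
// pipeline, wraps it in ck_tile::Rmsnorm2dFwd and returns the average time reported by
// ck_tile::launch_kernel.  Each instance file then only needs an explicit instantiation such
// as (hypothetical parameter values, shown for illustration only):
//
//   template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 64, 1, true, false, false>>(const S&, A);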
+ +#include +#include "rmsnorm2d_fwd.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = rmsnorm2d_fwd_args; + +template +using trait_ = rmsnorm2d_fwd_traits_; + +template +float rmsnorm2d_fwd_(const S& s, A a) +{ + using DataType = typename Traits_::DataType; + + using PipelineProblem = + ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType, + typename RmsnormTypeConfig::GammaDataType, + typename RmsnormTypeConfig::ComputeDataType, + typename RmsnormTypeConfig::YDataType, + typename RmsnormTypeConfig::InvRmsDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kSaveInvRms, + Traits_::kTwoPass>; + + using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::Rmsnorm2dFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..698a8b43eb9329f5bfb0c61b78cea98a0cf07f5f --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -0,0 +1,179 @@ +#include "ck_tile/host.hpp" +#include "rmsnorm2d_fwd.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("save_rms", "0", "save rms(invrms) or not. 
set to 1 in training case") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using TypeConfig = RmsnormTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + + using InvRmsDataType = + std::conditional_t; + + using ComputeDataType = typename TypeConfig::ComputeDataType; + + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); + + ck_tile::HostTensor invRms_host_ref({m}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + + rmsnorm2d_fwd_traits traits{data_type, SaveRms}; + + rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), + nullptr, + epsilon, + m, + n, + stride}; + + float ave_time = rmsnorm2d_fwd( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + std::size_t num_byte = + sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_rmsnorm2d_fwd( + x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + + y_buf.FromDevice(y_host_dev.data()); + + auto [rtol, atol] = get_elimit(); + if(stride == n) + { + pass = ck_tile::check_err( + y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector y_host_dev_row(y_host_dev.begin() + i_r * stride, + y_host_dev.begin() + i_r * stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * stride, + y_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(y_host_dev_row, + y_host_ref_row, + std::string("OUT[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + + std::cout << ", valid:" << (pass ? 
"y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + int save_rms = arg_parser.get_int("save_rms"); + if(data_type == "fp16" && save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp16" && !save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && !save_rms) + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b4d429d46f4a1a418527bf515ccd5e06e6243352 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/rmsnorm2d.hpp" +#include + +template +struct RmsnormTypeConfig; + +template <> +struct RmsnormTypeConfig +{ + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using InvRmsDataType = ck_tile::half_t; + using ComputeDataType = float; +}; + +template <> +struct RmsnormTypeConfig +{ + using XDataType = ck_tile::bf16_t; + using YDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using InvRmsDataType = ck_tile::bf16_t; + using ComputeDataType = float; +}; + +// runtime args +struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs +{ +}; + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct rmsnorm2d_fwd_traits_ +{ + using DataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using 
Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +template +float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a); + +// This is the public API, will be generated by script +struct rmsnorm2d_fwd_traits +{ + std::string data_type; + bool save_rms; +}; + +float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/10_rmsnorm2d/script/perf_test.sh b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..7b9d0820fd4bf5de41c049b97d005c07cf44c124 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh @@ -0,0 +1,37 @@ +#!/bin/sh +EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" + +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..758d6de54680cc303068044c5e0fc8d27baba4b2 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -0,0 +1,30 @@ +#!/bin/sh +EXE="$(find . 
-name tile_rmsnorm2d_fwd -type f | head -n 1)" + +for pr_i in "fp16" "bf16" ; do +$EXE -prec=$pr_i -m=99 -n=13 +$EXE -prec=$pr_i -m=17 -n=16 +$EXE -prec=$pr_i -m=1 -n=100 +$EXE -prec=$pr_i -m=4 -n=128 +$EXE -prec=$pr_i -m=80 -n=127 +$EXE -prec=$pr_i -m=22 -n=255 -stride=256 +$EXE -prec=$pr_i -m=7 -n=599 +$EXE -prec=$pr_i -m=19 -n=512 +$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 +$EXE -prec=$pr_i -m=11 -n=510 +$EXE -prec=$pr_i -m=171 -n=676 -stride=818 +$EXE -prec=$pr_i -m=91 -n=636 +$EXE -prec=$pr_i -m=12 -n=768 -stride=800 +$EXE -prec=$pr_i -m=100 -n=766 -stride=812 +$EXE -prec=$pr_i -m=31 -n=1024 +$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 +$EXE -prec=$pr_i -m=8 -n=1501 +$EXE -prec=$pr_i -m=3 -n=1826 +$EXE -prec=$pr_i -m=5 -n=2040 +$EXE -prec=$pr_i -m=7 -n=2734 +$EXE -prec=$pr_i -m=1 -n=3182 +$EXE -prec=$pr_i -m=9 -n=4096 +$EXE -prec=$pr_i -m=3 -n=8192 +$EXE -prec=$pr_i -m=1 -n=10547 +$EXE -prec=$pr_i -m=3 -n=17134 +done diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b0c3cef7a991bf0b208e84218763211c8426dd5 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt @@ -0,0 +1,25 @@ +set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp) +target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS}) + +set(TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) + +set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd") +add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL example_add_rmsnorm2d_rdquant_fwd.cpp) +target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md b/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md new file mode 100644 index 0000000000000000000000000000000000000000..960369b78d795774d2682ed9cb97ddda4ef6429c --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md @@ -0,0 +1,22 @@ +# Add + Rmsnorm2D + rowwise dynamic quantization forward + +This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization forward using ck_tile tile-programming implementation. Rdquant is short for rowwise dynamic quantization here. 
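Per row, the fused operation is roughly the following (a minimal summary derived from the
host reference used for validation in this example; `x` is optionally written out when
`save_x` is enabled):

```
x      = a + b
y      = rmsnorm(x, gamma, epsilon)
yscale = rowwise_absmax(y) / 127     # 127 = int8 max
qy     = int8(y / yscale)            # rounded/saturated per element
```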
+ +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_add_rmsnorm2d_rdquant_fwd -j +``` +This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd` + +## cmdline +``` +args: + -m m dimension (default:3328) + -n m dimension (default:4096) + -e epsilon (default:1e-5) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..43bc9a6cfe65edff55c0d20863c3b0cf32346521 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -0,0 +1,279 @@ +#include "ck_tile/host.hpp" +#include "add_rmsnorm2d_rdquant_fwd.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("save_x", "1", "save rms(invrms) or not. set to 1 in training case") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using TypeConfig = AddRmsnormRdquantTypeConfig; + + using ADataType = typename TypeConfig::ADataType; + using BDataType = typename TypeConfig::BDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using XDataType = typename TypeConfig::XDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor a_host({m, n}, {stride, 1}); + ck_tile::HostTensor b_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor x_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor x_host_dev({m, n}, {stride, 1}); + + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, 
.5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.data()); + b_buf.ToDevice(b_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + + add_rmsnorm2d_rdquant_fwd_traits traits{data_type, SaveX}; + + add_rmsnorm2d_rdquant_fwd_args args{a_buf.GetDeviceBuffer(), + b_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + x_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + epsilon, + m, + n, + stride}; + + float ave_time = add_rmsnorm2d_rdquant_fwd( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(BDataType) * m * n + + sizeof(GammaDataType) * n + sizeof(YScaleDataType) * m + + sizeof(QYDataType) * m * n; + + if constexpr(SaveX) + num_byte += sizeof(XDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + using InvRmsDataType = DataType; + + // Add + { + auto op = [](const auto& v0, const auto& v1) { return v0 + v1; }; + ck_tile::reference_binary_elementwise( + a_host, b_host, x_host_ref, op); + + x_buf.FromDevice(x_host_dev.data()); + + auto [rtol, atol] = get_elimit(); + if(stride == n) + { + pass = ck_tile::check_err( + x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + ck_tile::HostTensor y_host({m, n}); + // Rmsnorm2d + { + ck_tile::HostTensor invRms_host_ref({m}); + + // CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for + // simplicity + ck_tile::reference_rmsnorm2d_fwd( + x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + 
ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == n) + { + pass = ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + int save_x = arg_parser.get_int("save_x"); + if(data_type == "fp16" && save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp16" && !save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && !save_x) + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..443b9b10248d903b90a797008dbd0521d8dfcf27 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
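// The type configurations below keep the elementwise inputs a/b, gamma and the optional x
// output in the requested 16-bit type (fp16 or bf16), accumulate in fp32 (ComputeDataType),
// and emit the per-row scale yscale as fp32 and the quantized output qy as int8.
// add_rmsnorm2d_rdquant_fwd_args simply re-uses ck_tile::AddRmsnorm2dRdquantFwdHostArgs, so
// the host fills exactly the field set the ck_tile kernel expects.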
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp" +#include + +template +struct AddRmsnormRdquantTypeConfig; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using XDataType = ck_tile::half_t; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using XDataType = ck_tile::bf16_t; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +// runtime args +struct add_rmsnorm2d_rdquant_fwd_args : public ck_tile::AddRmsnorm2dRdquantFwdHostArgs +{ +}; + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct add_rmsnorm2d_rdquant_fwd_traits_ +{ + using DataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveX = kSaveX_; + static constexpr bool kThreePass = kThreePass_; +}; + +template +float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_rdquant_fwd_args a); + +// This is the public API, will be generated by script +struct add_rmsnorm2d_rdquant_fwd_traits +{ + std::string data_type; + bool save_x; +}; + +float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits, + add_rmsnorm2d_rdquant_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ada4c6f2da22cbbd24c3e7a50d1ef6566f2d35af --- 
/dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -0,0 +1,280 @@ +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using ADataType = DataType; + using BDataType = DataType; + using GammaDataType = DataType; + using XDataType = DataType; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor a_host({m, n}, {stride, 1}); + ck_tile::HostTensor b_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor x_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor x_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.data()); + b_buf.ToDevice(b_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + constexpr bool kThreePass = true; + + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<4, 128>; + using WarpTile = ck_tile::sequence<1, 64>; + using Vector = ck_tile::sequence<1, 1>; + + using Shape = ck_tile::Generic2dBlockShape; + using Problem = 
ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem; + + using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass; + using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass; + using Pipeline = std::conditional_t; + using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; + + ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(), + b_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + x_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + epsilon, + m, + n, + stride}; + + auto kargs = Kernel::MakeKargs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + using InvRmsDataType = DataType; + + // Add + { + auto op = [](const auto& v0, const auto& v1) { return v0 + v1; }; + ck_tile::reference_binary_elementwise( + a_host, b_host, x_host_ref, op); + + x_buf.FromDevice(x_host_dev.data()); + + auto [rtol, atol] = get_elimit(); + if(stride == n) + { + pass = ck_tile::check_err( + x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + ck_tile::HostTensor y_host({m, n}); + // Rmsnorm2d + { + ck_tile::HostTensor invRms_host_ref({m}); + + // CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for + // simplicity + ck_tile::reference_rmsnorm2d_fwd( + x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == n) + { + pass = ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout 
<< "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride + << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..966c5bd02f240ca5e26d90ab16d70e0fb48f36ea --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "add_rmsnorm2d_rdquant_fwd.hpp" + +template +using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; + +template +float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/, + add_rmsnorm2d_rdquant_fwd_args a, + const ck_tile::stream_config& s) +{ + float r = -1; + // clang-format off + // rm rn tm tn vn pd x 3p + if(a.n <= 64) { + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else 
if(a.n > 4096) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + return r; + // clang-format on +} + +float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t, + add_rmsnorm2d_rdquant_fwd_args a, + const ck_tile::stream_config& s) +{ + + // Only support instance of save_x == true for now + assert(t.save_x); + if(t.data_type.compare("fp16") == 0) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.data_type.compare("bf16") == 0) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else + throw std::runtime_error("Without supported instances!"); +} diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5495e3c9aba7673ad9c53856071772ca1d3d1978 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +#if 0 +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +#endif + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8bbfdc85890e06a13f903759150a07fb3f51bc13 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..381a11fc806e14e69d34bfe305449e8b83a6f5f1 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2fefac69342b0fe3697d83f67e62d1f49f654477 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..263713bbc7a27e225949699672424a45fb0dc325 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c62c596fabf60d5623bdcc95c9078411d5218519 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
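+
+// Column legend for the trait_<> rows below (assumed mapping, matching the
+// parameters of add_rmsnorm2d_rdquant_fwd_traits_ declared in
+// add_rmsnorm2d_rdquant_fwd.hpp; the mapping is not stated explicitly in the source):
+//   rm = Repeat_M_, rn = Repeat_N_, tm = ThreadPerBlock_M_, tn = ThreadPerBlock_N_,
+//   vn = Vector_N_, pd = kPadN_, x = kSaveX_, 3p = kThreePass_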
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4951f6ab9de2a3fb8feafd20077ac19c71b8af6 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c7ee48e8e5d0a4a1a8a7424bee281cc73534397 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8659dc82b3052de93cc62d24c7e29ab74cf55136 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
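+
+// These per-N explicit instantiations back the runtime dispatcher in
+// add_rmsnorm2d_rdquant_fwd_api.cpp, which picks a tile configuration from the
+// runtime n and its divisibility. A minimal call sketch of the public API
+// (names from add_rmsnorm2d_rdquant_fwd.hpp; device-buffer setup omitted and
+// stream_config fields left to the caller):
+//
+//   add_rmsnorm2d_rdquant_fwd_traits t{"bf16", /*save_x=*/true};
+//   add_rmsnorm2d_rdquant_fwd_args   a{/* a, b, gamma, x, yscale, qy pointers,
+//                                         epsilon, m, n, stride */};
+//   float r = add_rmsnorm2d_rdquant_fwd(t, a, ck_tile::stream_config{/*...*/});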
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f15f11b47dcfc7a643251a67fb9673444de0a61 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ffdacbdcd01e8d7c15008b504f4b2d611b9ddf4 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +#if 0 +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +#endif + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..355109965166185f79865bfd24ae51c300bd2c40 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
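+
+// Worked example of the shape arithmetic in add_rmsnorm2d_rdquant_fwd_traits_
+// (hypothetical parameter values, assuming warpSize == 64; actual instances may differ):
+//   ThreadPerBlock = 4 x 64, Repeat = 1 x 1, Vector_N = 4
+//   is_warp_per_row = (64 <= warpSize) -> true, total_warps = 4 * 64 / 64 = 4
+//   BlockWarps = 4 x 1
+//   Block_M = 1 * 4 = 4,  Block_N = 1 * 64 * 4 = 256
+//   Warp_M  = 4 / 4 = 1,  Warp_N  = 64 / 1 * 4 = 256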
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d4d0474c2709a83344e127b7ab629d9c99820b41 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2cb300eda6b7cd893d70dbaaedf8cd50602b4109 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb0ceb4c587be0aa5f1139e9df94e0ef600a23cc --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3a241a3c93fba6238783e6f56ccec29c9bc7b983 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d3094679f95cf2207568e0edb6a7d0e300bddfc9 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..919bc177e85fd0487c9a34646bfa3595eb8b7a0f --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8a44f5e00fd098d237c1c21349ea10f5cad97ab9 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5c4f05ec3c690fe24676b74fc887102ced3592cb --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6baaad471a6ca486ccf533f03c9b500f5bbccb01 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -0,0 +1,67 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
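+
+// Shared launcher used by every *_instance.cpp in this directory: it assembles an
+// AddRmsnorm2dRdquantFwdPipelineProblem from AddRmsnormRdquantTypeConfig and the
+// compile-time Traits_, selects the one-pass or three-pass pipeline via
+// Traits_::kThreePass, and launches the resulting AddRmsnorm2dRdquantFwd kernel
+// through ck_tile::launch_kernel / ck_tile::make_kernel.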
+ +#include +#include "add_rmsnorm2d_rdquant_fwd.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = add_rmsnorm2d_rdquant_fwd_args; + +template +using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; + +template +float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) +{ + using DataType = typename Traits_::DataType; + + using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem< + typename AddRmsnormRdquantTypeConfig::ADataType, + typename AddRmsnormRdquantTypeConfig::BDataType, + typename AddRmsnormRdquantTypeConfig::GammaDataType, + typename AddRmsnormRdquantTypeConfig::ComputeDataType, + typename AddRmsnormRdquantTypeConfig::XDataType, + typename AddRmsnormRdquantTypeConfig::YScaleDataType, + typename AddRmsnormRdquantTypeConfig::QYDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kSaveX, + Traits_::kThreePass>; + + using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass; + using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..d02b0bab337a13a2e754960bb2b96509c4abbcbc --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh @@ -0,0 +1,37 @@ +#!/bin/sh +EXE="$(find . 
-name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)" + +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..b60f5fcf200c3a0bd717b48fca7500320a79418f --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh @@ -0,0 +1,30 @@ +#!/bin/sh +EXE="$(find . 
-name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)" + +for pr_i in "fp16" "bf16" ; do +$EXE -prec=$pr_i -m=99 -n=13 +$EXE -prec=$pr_i -m=17 -n=16 +$EXE -prec=$pr_i -m=1 -n=100 +$EXE -prec=$pr_i -m=4 -n=128 +$EXE -prec=$pr_i -m=80 -n=127 +$EXE -prec=$pr_i -m=22 -n=255 -stride=256 +$EXE -prec=$pr_i -m=7 -n=599 +$EXE -prec=$pr_i -m=19 -n=512 +$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 +$EXE -prec=$pr_i -m=11 -n=510 +$EXE -prec=$pr_i -m=171 -n=676 -stride=818 +$EXE -prec=$pr_i -m=91 -n=636 +$EXE -prec=$pr_i -m=12 -n=768 -stride=800 +$EXE -prec=$pr_i -m=100 -n=766 -stride=812 +$EXE -prec=$pr_i -m=31 -n=1024 +$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 +$EXE -prec=$pr_i -m=8 -n=1501 +$EXE -prec=$pr_i -m=3 -n=1826 +$EXE -prec=$pr_i -m=5 -n=2040 +$EXE -prec=$pr_i -m=7 -n=2734 +$EXE -prec=$pr_i -m=1 -n=3182 +$EXE -prec=$pr_i -m=9 -n=4096 +$EXE -prec=$pr_i -m=3 -n=8192 +$EXE -prec=$pr_i -m=1 -n=10547 +$EXE -prec=$pr_i -m=3 -n=17134 +done diff --git a/example/ck_tile/12_smoothquant/CMakeLists.txt b/example/ck_tile/12_smoothquant/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..09a56c6dabf6fb158ec5108e210a6b99e5b104db --- /dev/null +++ b/example/ck_tile/12_smoothquant/CMakeLists.txt @@ -0,0 +1,24 @@ +function (add_smoothquant_example TARGET_NAME MAIN_SRC) + message("adding ${TARGET_NAME}") + # not using add_example_executable() to add target, since we don't want this to have + # to be included in "make all/install/check" + add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) + target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + + foreach(source IN LISTS ARGN) + list(APPEND INSTANCE_SRCS ${source}) + endforeach() + + target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS}) + + set(COMPILE_OPTIONS) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + + target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) +endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC) + +file(GLOB INSTANCE_SRCS instances/*.cpp) + +add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS}) +add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp) diff --git a/example/ck_tile/12_smoothquant/README.md b/example/ck_tile/12_smoothquant/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d6b815f8cf31961703dd7960308986c7126dbb11 --- /dev/null +++ b/example/ck_tile/12_smoothquant/README.md @@ -0,0 +1,21 @@ +# smoothquant + +This folder contains example for smoothquant using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... 
+make tile_smoothquant -j +``` +This will result in an executable `build/bin/tile_smoothquant` + +## cmdline +``` +args: + -m m dimension (default:3328) + -n m dimension (default:4096) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` diff --git a/example/ck_tile/12_smoothquant/example_smoothquant.cpp b/example/ck_tile/12_smoothquant/example_smoothquant.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3a26eb6a77904f47a810f53271b0b3657fc474ee --- /dev/null +++ b/example/ck_tile/12_smoothquant/example_smoothquant.cpp @@ -0,0 +1,237 @@ +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/smoothquant.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using XDataType = DataType; + using XScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor xscale_host({n}); + + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + xscale_buf.ToDevice(xscale_host.data()); + + constexpr bool kTwoPass = true; + + using BlockWarps = ck_tile::sequence<2, 2>; + using BlockTile = ck_tile::sequence<2, 128>; + using WarpTile = ck_tile::sequence<1, 64>; + using Vector = ck_tile::sequence<1, 1>; + + using Shape = ck_tile::Generic2dBlockShape; + using Problem = ck_tile::SmoothquantPipelineProblem; + + using OnePassPipeline = 
ck_tile::SmoothquantPipelineOnePass; + using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass; + using Pipeline = std::conditional_t; + using Kernel = ck_tile::Smoothquant; + + ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(), + xscale_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + m, + n, + stride}; + + auto kargs = Kernel::MakeKargs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + auto s = ck_tile::stream_config{nullptr, true, 1, warmup, repeat}; + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + ck_tile::HostTensor y_host({m, n}, {stride, 1}); + // smooth outlier + { + auto f = [&](auto n_) { + auto v_xscale = ck_tile::type_convert(xscale_host(n_)); + + for(int m_ = 0; m_ < m; ++m_) + { + auto v_x = ck_tile::type_convert(x_host(m_, n_)); + y_host(m_, n_) = v_x * v_xscale; + } + }; + + ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( + std::thread::hardware_concurrency()); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == n) + { + pass = ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride + << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + /*else if(data_type == "bf16") + { + return run(arg_parser) ? 
0 : -2; + }*/ + + return -3; +} diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b25361da2ffd5834a7d1d072dfa282ce6c29f68f --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +#if 0 +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +template float smoothquant_>(const S&, A); +#endif + +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1536_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1536_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a332fe410dfb6ef9db62049f2482db08ada0907 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n2048_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n2048_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bdf5804e43affb5bbd7164b98ef860fad5016be5 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n256_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n256_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..774c977f2e710fe4b8b74d3d848925b5042e117f --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
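+
+// Column legend for the trait_<> rows below (assumed to mirror the
+// add_rmsnorm2d_rdquant legend; only kPadN and kTwoPass are confirmed by
+// smoothquant_instance_common.hpp):
+//   rm = Repeat_M, rn = Repeat_N, tm = ThreadPerBlock_M, tn = ThreadPerBlock_N,
+//   vn = Vector_N, pd = kPadN, 2p = kTwoPass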
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n3072_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n3072_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c571ef443e8d973fd4a07fb1db95bdb91cfa4250 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80e4b3a296d303a5f23c3bacc05dd6053150fa2d --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7f776a6e4653fb07bf2bdc59be8a92e3f0689049 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n512_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n512_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12bc90b66966f45c8dd6444e8d0c67ccd6ef2c2a --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cee186063a0bce900a601d4033fe5ffe7ffc0fc --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n768_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n768_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aca7f7eb4eb92d93db06329287b549dbcd85b1ed --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be5fecaca18ee05ea64bc7eabdf65c6780f4d989 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +#if 0 +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +template float smoothquant_>(const S&, A); +#endif + +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1536_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1536_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..59fe1487502c06569d38ac95b84c10049cf5e221 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
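+
+// The variants within an n-bucket typically differ only in the vector width along N;
+// smoothquant_fwd_api.cpp selects among them at runtime based on whether n is
+// divisible by 8, 4, or 2, falling back to a width of 1 otherwise.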
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n2048_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n2048_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3710a6ab4c1382441ca2dfee2adaf79899635d4 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n256_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n256_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b1bca7aa47bde9e726d990c8c1ee77df4116c98 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n3072_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n3072_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..205ba130e48ebe4b6919fef3037da3213cccee46 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..96503ac913823cbdfe1b4685ab9215868667e8b7 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..36e5e0bb1474514eb70f5d4ab752037a88c024c9 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n512_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n512_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f09932e295c0f24013493f07bd391efb255e3468 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..023cd0be6ec606e464fb344c52ef5a5867e39b97 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n768_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n768_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5dcf560c74e2286cbcfc73ed3f7fb95dad50b08c --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..962755f6ef605eadc24ba8d0086fbd4cb0d2f2e0 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "smoothquant.hpp" + +template +using trait_ = smoothquant_traits_; + +template +float smoothquant_dispatch(smoothquant_traits /*t*/, + smoothquant_args a, + const ck_tile::stream_config& s) +{ + float r = -1; + // clang-format off + // rm rn tm tn vn pd 2p + if(a.n <= 64) { + r = smoothquant_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n > 4096) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + return r; + // clang-format on +} + +float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + return smoothquant_dispatch(t, a, s); + } + else if(t.data_type.compare("bf16") == 0) + { + return smoothquant_dispatch(t, a, s); + } + else + throw std::runtime_error("Without supported instances!"); +} diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..cdf93f6fcfd3d523c7723e28e006d089bf49f350 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp @@ -0,0 +1,62 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "smoothquant.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = smoothquant_args; + +template +using trait_ = smoothquant_traits_; + +template +float smoothquant_(const S& s, A a) +{ + using DataType = typename Traits_::DataType; + + using PipelineProblem = ck_tile::SmoothquantPipelineProblem< + typename SmoothquantTypeConfig::XDataType, + typename SmoothquantTypeConfig::XScaleDataType, + typename SmoothquantTypeConfig::ComputeDataType, + typename SmoothquantTypeConfig::YScaleDataType, + typename SmoothquantTypeConfig::QYDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kTwoPass>; + + using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass; + using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::Smoothquant; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/12_smoothquant/script/perf_test.sh b/example/ck_tile/12_smoothquant/script/perf_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..741eb32ec12ffddac23906f11542b6bfd72b093e --- /dev/null +++ b/example/ck_tile/12_smoothquant/script/perf_test.sh @@ -0,0 +1,37 @@ + +EXE="$(find . 
-name tile_smoothquant -type f | head -n 1)" + +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/12_smoothquant/script/smoke_test.sh b/example/ck_tile/12_smoothquant/script/smoke_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..d08e063966549b83e72560bf9d10363daefea000 --- /dev/null +++ b/example/ck_tile/12_smoothquant/script/smoke_test.sh @@ -0,0 +1,30 @@ +#!/bin/sh +EXE="$(find . 
-name tile_smoothquant -type f | head -n 1)" + +for pr_i in "fp16" "bf16" ; do +$EXE -prec=$pr_i -m=99 -n=13 +$EXE -prec=$pr_i -m=17 -n=16 +$EXE -prec=$pr_i -m=1 -n=100 +$EXE -prec=$pr_i -m=4 -n=128 +$EXE -prec=$pr_i -m=80 -n=127 +$EXE -prec=$pr_i -m=22 -n=255 -stride=256 +$EXE -prec=$pr_i -m=7 -n=599 +$EXE -prec=$pr_i -m=19 -n=512 +$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 +$EXE -prec=$pr_i -m=11 -n=510 +$EXE -prec=$pr_i -m=171 -n=676 -stride=818 +$EXE -prec=$pr_i -m=91 -n=636 +$EXE -prec=$pr_i -m=12 -n=768 -stride=800 +$EXE -prec=$pr_i -m=100 -n=766 -stride=812 +$EXE -prec=$pr_i -m=31 -n=1024 +$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 +$EXE -prec=$pr_i -m=8 -n=1501 +$EXE -prec=$pr_i -m=3 -n=1826 +$EXE -prec=$pr_i -m=5 -n=2040 +$EXE -prec=$pr_i -m=7 -n=2734 +$EXE -prec=$pr_i -m=1 -n=3182 +$EXE -prec=$pr_i -m=9 -n=4096 +$EXE -prec=$pr_i -m=3 -n=8192 +$EXE -prec=$pr_i -m=1 -n=10547 +$EXE -prec=$pr_i -m=3 -n=17134 +done diff --git a/example/ck_tile/12_smoothquant/smoothquant.cpp b/example/ck_tile/12_smoothquant/smoothquant.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed01d654fda239d87f34a6131d4c4d9744f9c859 --- /dev/null +++ b/example/ck_tile/12_smoothquant/smoothquant.cpp @@ -0,0 +1,218 @@ +#include "ck_tile/host.hpp" +#include "smoothquant.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + std::string data_type = arg_parser.get_str("prec"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using TypeConfig = SmoothquantTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using XScaleDataType = typename TypeConfig::XScaleDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = typename TypeConfig::ComputeDataType; + + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor xscale_host({n}); + + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); 
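+ // x is filled uniformly in [-0.5, 0.5]; the per-column smoothing scale below is kept strictly positive.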
+ ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + xscale_buf.ToDevice(xscale_host.data()); + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + + smoothquant_traits traits{data_type}; + + smoothquant_args args{x_buf.GetDeviceBuffer(), + xscale_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + m, + n, + stride}; + + float ave_time = smoothquant( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n + + sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + ck_tile::HostTensor y_host({m, n}, {stride, 1}); + // smooth outlier + { + auto f = [&](auto n_) { + auto v_xscale = ck_tile::type_convert(xscale_host(n_)); + + for(int m_ = 0; m_ < m; ++m_) + { + auto v_x = ck_tile::type_convert(x_host(m_, n_)); + y_host(m_, n_) = v_x * v_xscale; + } + }; + + ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( + std::thread::hardware_concurrency()); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == n) + { + pass = ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 
0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/12_smoothquant/smoothquant.hpp b/example/ck_tile/12_smoothquant/smoothquant.hpp new file mode 100644 index 0000000000000000000000000000000000000000..26a598db55bc19c5ce9e1035eeef2add79fd1f35 --- /dev/null +++ b/example/ck_tile/12_smoothquant/smoothquant.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/smoothquant.hpp" +#include + +template +struct SmoothquantTypeConfig; + +template <> +struct SmoothquantTypeConfig +{ + using XDataType = ck_tile::half_t; + using XScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +template <> +struct SmoothquantTypeConfig +{ + using XDataType = ck_tile::bf16_t; + using XScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +// runtime args +struct smoothquant_args : public ck_tile::SmoothquantHostArgs +{ +}; + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct smoothquant_traits_ +{ + using DataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +template +float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a); + +// This is the public API, will be generated by script +struct smoothquant_traits +{ + std::string data_type; +}; + +float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/13_moe_sorting/CMakeLists.txt b/example/ck_tile/13_moe_sorting/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..09f3e4ac4e9b46a4f6f995e8f0679dc0fb2a5c35 --- 
/dev/null +++ b/example/ck_tile/13_moe_sorting/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp) +target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + +set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS}) diff --git a/example/ck_tile/13_moe_sorting/README.md b/example/ck_tile/13_moe_sorting/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7b6792dd957cb243a8dde35eedad6f9a89d49844 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/README.md @@ -0,0 +1,27 @@ +# moe-sorting + +This folder contains example for moe-sorting kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input&weight is a `token*topk` 2d matrix. The op rearange the input weight ids into different experts and feed into fuse moe gemm kernel. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_moe_sorting -j +``` +This will result in an executable `build/bin/tile_example_moe_sorting` + +## example +``` +args: + -v weather do CPU validation or not (default:1) + -pr_i index data type. (currently only fp32 supported now) (default:int32) + -pr_w output weight data type(currently only fp32 supported now) (default:fp32) + -t number of input tokens (default:32) + -e number of experts (default:8) + -k topk (default:2) + -st_i row stride of input, -1 means same as experts (default:-1) + -seed seed to be used, -1 means random every time (default:-1) + -kname when set to 1 it will print kernel name (default:0) + +``` diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2c4df105838d6ae1ebe3ef5058330a74fd55585 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "moe_sorting_api.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("pr_i", "int32", "index data type. 
(currently only int32 supported now)") + .insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)") + .insert("t", "128", "number of input tokens") + .insert("e", "8", "number of num_experts") + .insert("k", "4", "topk") + .insert("unit", "32", "unit_size") + .insert("moe_buf_size", "0", "moe_buf_size") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "when set to 1 it will print kernel name") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +void topid_unique_gen( + std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) +{ + size_t total_size = topk * tokens; + std::srand(seed); + std::set unique_set; + IndexType current_v; + for(size_t i = 0; i < total_size; i++) + { + if(i % topk == 0) + { + unique_set.clear(); + } + current_v = std::rand() % num_expert; + while(unique_set.find(current_v) != unique_set.end()) + { + current_v = std::rand() % num_expert; + } + unique_set.insert(current_v); + host_tensor[i] = current_v; + } +} + +template +bool test_moe_sorting(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string index_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + int tokens = args.get_int("t"); + int num_experts = args.get_int("e"); + int topk = args.get_int("k"); + int seed = args.get_int("seed"); + int unit_size = args.get_int("unit"); + int moe_buf_size = args.get_int("moe_buf_size"); + int kname = args.get_int("kname"); + int warmup = args.get_int("warmup"); + int repeat = args.get_int("repeat"); + int max_output_ids = + ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); + + if(seed < 0) + { + seed = std::time(nullptr); + } + + if(topk > num_experts) + { + printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n", + topk, + num_experts); + return false; + } + + // tokens already considered batch size + ck_tile::HostTensor topk_ids_host({tokens, topk}, {topk, 1}); + ck_tile::HostTensor weights_host({tokens, topk}, {topk, 1}); + ck_tile::HostTensor sorted_ids_host({max_output_ids}, {1}); + ck_tile::HostTensor sorted_weights_host({max_output_ids}, {1}); + ck_tile::HostTensor sorted_expert_ids_host({max_output_ids / unit_size}, {1}); + ck_tile::HostTensor sorted_id_cnt_host({1}, {1}); + ck_tile::HostTensor moe_buf_host({moe_buf_size}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(weights_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); + topid_unique_gen(topk_ids_host.mData, tokens, topk, num_experts, seed); + + ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_expert_ids_dev( + sorted_expert_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); + + topk_ids_dev.ToDevice(topk_ids_host.data()); + weights_dev.ToDevice(weights_host.data()); + if(moe_buf_size > 
0) + { + moe_buf_dev.ToDevice(moe_buf_host.data()); + } + + moe_sorting_trait trait{index_prec, weight_prec}; + + moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), + weights_dev.GetDeviceBuffer(), + sorted_ids_dev.GetDeviceBuffer(), + sorted_weights_dev.GetDeviceBuffer(), + sorted_expert_ids_dev.GetDeviceBuffer(), + sorted_id_cnt_dev.GetDeviceBuffer(), + moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, + tokens, + unit_size, + num_experts, + topk, + static_cast(moe_buf_size * sizeof(float))}; + + ck_tile::stream_config sc{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + warmup, + repeat}; + auto ms = moe_sorting(trait, karg, sc); + printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", + index_prec.c_str(), + weight_prec.c_str(), + tokens, + num_experts, + topk, + ms); + if(ms < 0) + printf("not supported\n"); + fflush(stdout); + if(ms < 0) + { + return false; + } + + sorted_ids_dev.FromDevice(sorted_ids_host.data()); + sorted_weights_dev.FromDevice(sorted_weights_host.data()); + sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data()); + sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data()); + if(moe_buf_size > 0) + { + moe_buf_dev.FromDevice(moe_buf_host.data()); + } + + bool rtn = true; + if(validate) + { + ck_tile::HostTensor sorted_ids_ref({max_output_ids}, {1}); + ck_tile::HostTensor sorted_weights_ref({max_output_ids}, {1}); + ck_tile::HostTensor sorted_expert_ids_ref({max_output_ids / unit_size}, {1}); + + int32_t ref_total_tokens_post_pad = 0; + ck_tile::reference_moe_sorting(topk_ids_host, + weights_host, + sorted_ids_ref, + sorted_weights_ref, + sorted_expert_ids_ref, + ref_total_tokens_post_pad, + num_experts, + unit_size); + rtn &= ck_tile::check_err( + sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); + rtn &= ck_tile::check_err(sorted_weights_host, + sorted_weights_ref, + std::string("OUT Error: Incorrect w!"), + 1e-6, + 1e-6); + rtn &= ck_tile::check_err(sorted_expert_ids_host, + sorted_expert_ids_ref, + std::string("OUT Error: Incorrect eid!"), + 1e-6, + 1e-6); + if(moe_buf_size) + { + ck_tile::HostTensor moe_buf_ref({moe_buf_size}); + rtn &= ck_tile::check_err( + moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); + } + rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; + } + + printf("valid:%s\n", rtn ? "y" : "n"); + fflush(stdout); + return rtn; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + std::string index_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + + bool r = true; + if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0) + { + r &= test_moe_sorting(args); + } + return r ? 0 : -1; +} diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..25e99c5306fddc9786d06d1902ebcdf0171413f9 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
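+// Dispatcher for the moe_sorting example: for int32 indices / fp32 weights it derives an LDS I/O
+// unroll factor from ceil(tokens * topk / 64) and launches the matching MoeSortingKernel instantiation.
+// Unsupported precisions, experts > 127, or a misaligned moe_buf size return -1.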
+ +#include "moe_sorting_api.hpp" + +#define MOE_SORTING_DISPATCH(unroll_num_) \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + using ms_problem = ck_tile::MoeSortingProblem; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { + if(a.num_experts > 127) + { + printf("lds size exceed, only support experts <127 \n"); + return -1; + } + if(a.moe_buf_bytes % 16) + { + printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes); + return -1; + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64); + switch(smem_io_unroll_num) + { + case(1): { + MOE_SORTING_DISPATCH(1); + } + case(2): { + MOE_SORTING_DISPATCH(2); + } + case(3): { + MOE_SORTING_DISPATCH(3); + } + case(5): { + MOE_SORTING_DISPATCH(5); + } + case(6): { + MOE_SORTING_DISPATCH(6); + } + case(7): { + MOE_SORTING_DISPATCH(7); + } + case(8): { + MOE_SORTING_DISPATCH(8); + } + case(9): { + MOE_SORTING_DISPATCH(9); + } + case(10): { + MOE_SORTING_DISPATCH(10); + } + case(11): { + MOE_SORTING_DISPATCH(11); + } + default: { + MOE_SORTING_DISPATCH(4); + } + } + } + return -1; +} diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91b54932cec5bcb94a90374259619c9bfedbb7f0 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/moe_sorting.hpp" + +struct moe_sorting_trait +{ + std::string index_type; + std::string weight_type; // currently always float +}; + +struct moe_sorting_args : public ck_tile::MoeSortingHostArgs +{ +}; + +float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..1fc5eafcb005e8ce68832f94b28140fc6d848d8b --- /dev/null +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -0,0 +1,19 @@ +# #!/bin/sh + +EXE=./build/bin/tile_example_moe_sorting + +$EXE -t=80 -e=17 -moe_buf_size=16 +$EXE -t=111 -e=117 -moe_buf_size=4 +$EXE -t=1000 -e=55 -moe_buf_size=1024 +$EXE -t=99 -e=120 -moe_buf_size=10244 +$EXE -t=175 -e=64 -k=8 +$EXE -t=65 -e=8 -k=2 +$EXE -t=1 -e=25 +$EXE -t=31 -e=19 -k=15 +$EXE -t=81 -e=37 -k=7 +$EXE -t=23 -e=1 -k=1 +$EXE -t=127 -e=99 -k=19 +$EXE -t=71 -e=11 -k=11 +$EXE -t=1 -e=1 -k=1 +$EXE -t=99 -e=2 -k=1 +$EXE -t=333 -e=99 -k=13 \ No newline at end of file diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..15db0f46c4ec78e047415a58666653397a2f24f8 --- /dev/null +++ b/example/ck_tile/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) + +add_subdirectory(01_fmha) +add_subdirectory(02_layernorm2d) +add_subdirectory(03_gemm) +add_subdirectory(04_img2col) +add_subdirectory(05_reduce) +add_subdirectory(06_permute) +add_subdirectory(09_topk_softmax) +add_subdirectory(10_rmsnorm2d) +add_subdirectory(11_add_rmsnorm2d_rdquant) +add_subdirectory(12_smoothquant) +add_subdirectory(13_moe_sorting) diff --git a/example/ck_tile/remod.py b/example/ck_tile/remod.py new file mode 100644 index 0000000000000000000000000000000000000000..fdc0dcf5d7257ba049b26c9ce0f724e6f46bbb95 --- /dev/null +++ b/example/ck_tile/remod.py @@ -0,0 +1,21 @@ +import pathlib +from pathlib import Path +import subprocess +import os +import copy + +all_files = [] +for p in sorted(Path("./").rglob("*")): + if p.suffix in ['.hpp', '.cpp']: + all_files.append(pathlib.PurePath(p)) + + +# formatting +for x in all_files: + subprocess.Popen(f'dos2unix {str(x)}', shell=True) + cmd = f'clang-format-12 -style=file -i {str(x)}' + #for xp in x.parents: + #print(get_file_base(x)) + subprocess.Popen(cmd, shell=True) + +#print(all_files) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 94ba97c6e1404918fe805414e41c339a2010e03d..8b06894c05f006deb9792333f1b8839782380a6e 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck/config.h" +#include "ck/utility/env.hpp" #ifndef __HIPCC_RTC__ #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" @@ -11,6 +12,12 @@ #endif #endif +// environment variable to enable logging: +// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED +CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + +// to do: add various levels of logging with CK_LOG_LEVEL + #define CK_TIME_KERNEL 1 // constant address space for kernel parameter @@ -45,16 +52,38 @@ #define CK_USE_WAVES_PER_EU 0 #endif +// define general macros for various architectures +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) +#define __gfx9__ +#endif +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define __gfx94__ +#endif +#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) +#define __gfx101__ +#endif +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \ + defined(__gfx10_3_generic__) +#define __gfx103__ +#endif +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ + defined(__gfx1103__) || defined(__gfx11_generic__) +#define __gfx11__ +#endif +#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) +#define __gfx12__ +#endif + // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 -#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 -#elif defined(__gfx1030__) // for GPU code +#elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code +#elif defined(__gfx11__) || defined(__gfx12__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif @@ -62,36 +91,31 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 -#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code +#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 +#elif defined(__gfx11__) || defined(__gfx12__) +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8_GFX11 #endif // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_MFMA #endif -#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(defined(__gfx90a__) || defined(__gfx94__)) #define CK_USE_AMD_MFMA_BF16_1K_OP #endif -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) #define CK_USE_AMD_MFMA_GFX940 #endif -// WMMA instruction -#ifndef 
__HIP_DEVICE_COMPILE__ // for host code -#define CK_USE_AMD_WMMA -#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code -#define CK_USE_AMD_WMMA -#endif - // buffer load #define CK_USE_AMD_BUFFER_LOAD 1 @@ -104,15 +128,13 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif -#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__)) // for GPU code +#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #else #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 @@ -131,6 +153,12 @@ // inner product using V_DOT with DPP8 modifiers #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 +// LDS direct loads using inline assembly +#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 + +// set rounding to nearest even as default for f8 conversions +#define CK_USE_SR_F8_CONVERSION 0 + // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 @@ -203,17 +231,17 @@ // workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 -// flag to enable (1) or disable (0) the debugging output in some kernels -#define DEBUG_LOG 0 - // denorm test fix, required to work around dissue #ifndef CK_WORKAROUND_DENORM_FIX #define CK_WORKAROUND_DENORM_FIX 0 -#elif -// enable only on MI200 +#else +// enable only for gfx90a #define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) #endif // CK_WORKAROUND_DENORM_FIX +// set flag to 1 to build deprecated instances +#define CK_BUILD_DEPRECATED 1 + namespace ck { enum struct InMemoryDataOperationEnum diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 13dc5da5d179443b603cce5a2aa6274061856db3..0f0b7bd607424a2468f4b62a6cbf1ae6331ee52f 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -43,6 +43,9 @@ #ifndef CK_ENABLE_FP8 #define CK_ENABLE_FP8 "ON" #endif +#ifndef CK_ENABLE_BF8 +#define CK_ENABLE_BF8 "ON" +#endif #ifndef CK_ENABLE_FP16 #define CK_ENABLE_FP16 "ON" #endif @@ -66,6 +69,10 @@ #cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@ #endif +#ifndef CK_ENABLE_BF8 +#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@ +#endif + #ifndef CK_ENABLE_FP16 #cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@ #endif @@ -91,10 +98,17 @@ #endif // -// Instances supports in the current CK build +// CK kernels which support XDL (MI series) +// +#ifndef CK_USE_XDL +#cmakedefine CK_USE_XDL @CK_USE_XDL@ +#endif + +// +// CK Kernels which support WMMA (recent Navi series) // -#ifndef CK_ENABLE_INSTANCES_ONLY -#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@ +#ifndef CK_USE_WMMA +#cmakedefine CK_USE_WMMA @CK_USE_WMMA@ #endif // clang-format on diff --git a/include/ck/filesystem.hpp b/include/ck/filesystem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..41ab5f4ddbf73f03acf32e19a89433c7273beda6 --- /dev/null +++ b/include/ck/filesystem.hpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
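+// Feature-detects <filesystem> vs <experimental/filesystem> (or fails the build if neither is available),
+// re-exports the chosen implementation as CK::fs, and adds helpers for platform-specific
+// executable/library/object file names.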
+ +#ifndef GUARD_CK_FILESYSTEM_HPP_ +#define GUARD_CK_FILESYSTEM_HPP_ + +#include +#include + +// clang-format off +#if defined(CPPCHECK) + #define CK_HAS_FILESYSTEM 1 + #define CK_HAS_FILESYSTEM_TS 1 +#elif defined(_WIN32) + #if _MSC_VER >= 1920 + #define CK_HAS_FILESYSTEM 1 + #define CK_HAS_FILESYSTEM_TS 0 + #elif _MSC_VER >= 1900 + #define CK_HAS_FILESYSTEM 0 + #define CK_HAS_FILESYSTEM_TS 1 + #else + #define CK_HAS_FILESYSTEM 0 + #define CK_HAS_FILESYSTEM_TS 0 + #endif +#elif defined(__has_include) + #if __has_include() && __cplusplus >= 201703L + #define CK_HAS_FILESYSTEM 1 + #else + #define CK_HAS_FILESYSTEM 0 + #endif + #if __has_include() && __cplusplus >= 201103L + #define CK_HAS_FILESYSTEM_TS 1 + #else + #define CK_HAS_FILESYSTEM_TS 0 + #endif +#else + #define CK_HAS_FILESYSTEM 0 + #define CK_HAS_FILESYSTEM_TS 0 +#endif +// clang-format on + +#if CK_HAS_FILESYSTEM +#include +#elif CK_HAS_FILESYSTEM_TS +#include +#else +#error "No filesystem include available" +#endif + +namespace CK { + +#if CK_HAS_FILESYSTEM +namespace fs = ::std::filesystem; +#elif CK_HAS_FILESYSTEM_TS +namespace fs = ::std::experimental::filesystem; +#endif + +} // namespace CK + +inline std::string operator+(const std::string_view s, const CK::fs::path& path) +{ + return path.string().insert(0, s); +} + +inline std::string operator+(const CK::fs::path& path, const std::string_view s) +{ + return path.string().append(s); +} + +#define FS_ENUM_PERMS_ALL fs::perms::all + +#if CK_HAS_FILESYSTEM_TS +#ifdef __linux__ +#include +namespace CK { +inline fs::path weakly_canonical(const fs::path& path) +{ + std::string result(PATH_MAX, '\0'); + std::string p{path.is_relative() ? (fs::current_path() / path).string() : path.string()}; + char* retval = realpath(p.c_str(), &result[0]); + return (retval == nullptr) ? path : fs::path{result}; +} +} // namespace CK +#else +#error "Not implmeneted!" 
+#endif +#else +namespace CK { +inline fs::path weakly_canonical(const fs::path& path) { return fs::weakly_canonical(path); } +} // namespace CK +#endif + +namespace CK { + +#ifdef _WIN32 +constexpr std::string_view executable_postfix{".exe"}; +constexpr std::string_view library_prefix{""}; +constexpr std::string_view dynamic_library_postfix{".dll"}; +constexpr std::string_view static_library_postfix{".lib"}; +constexpr std::string_view object_file_postfix{".obj"}; +#else +constexpr std::string_view executable_postfix{""}; +constexpr std::string_view library_prefix{"lib"}; +constexpr std::string_view dynamic_library_postfix{".so"}; +constexpr std::string_view static_library_postfix{".a"}; +constexpr std::string_view object_file_postfix{".o"}; +#endif + +inline fs::path make_executable_name(const fs::path& path) +{ + return path.parent_path() / (path.filename() + executable_postfix); +} + +inline fs::path make_dynamic_library_name(const fs::path& path) +{ + return path.parent_path() / (library_prefix + path.filename() + dynamic_library_postfix); +} + +inline fs::path make_object_file_name(const fs::path& path) +{ + return path.parent_path() / (path.filename() + object_file_postfix); +} + +inline fs::path make_static_library_name(const fs::path& path) +{ + return path.parent_path() / (library_prefix + path.filename() + static_library_postfix); +} + +struct FsPathHash +{ + std::size_t operator()(const fs::path& path) const { return fs::hash_value(path); } +}; +} // namespace CK + +#endif // GUARD_CK_FILESYSTEM_HPP_ diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 24b3a50d29977907338a06f5e7f33873f690635b..3950ac596fe11cdafc9b5a85427b609dbd8637eb 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -27,7 +27,7 @@ inline std::string get_device_name() } const std::string raw_name(props.gcnArchName); - // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 static std::map device_name_map = { {"Ellesmere", "gfx803"}, {"Baffin", "gfx803"}, @@ -59,5 +59,42 @@ inline bool is_xdl_supported() ck::get_device_name() == "gfx942"; } +inline bool is_lds_direct_load_supported() +{ + // Check if direct loads from global memory to LDS are supported. 
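+ // i.e. gfx90a and gfx940/941/942 (the CDNA2/CDNA3 data-center GPUs).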
+ return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || + ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; +} + +inline bool is_bf16_atomic_supported() +{ + return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942"; +} + +inline bool is_gfx101_supported() +{ + return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" || + ck::get_device_name() == "gfx1012"; +} + +inline bool is_gfx103_supported() +{ + return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" || + ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" || + ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; +} + +inline bool is_gfx11_supported() +{ + return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; +} + +inline bool is_gfx12_supported() +{ + return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; +} + } // namespace ck #endif diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp new file mode 100644 index 0000000000000000000000000000000000000000..918fb28ea903801fff2e6542389a8c7cd00888ee --- /dev/null +++ b/include/ck/host_utility/flush_cache.hpp @@ -0,0 +1,377 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/flush_icache.hpp" +namespace ck { +namespace utility { + +template +struct RotatingMemWrapperMultiD +{ + static constexpr index_t NumDs = DsDataType::Size(); + + using ADataType = decltype(Argument::p_a_grid); + using BDataType = decltype(Argument::p_b_grid); + using DsGridPointer = decltype(Argument::p_ds_grid); + + RotatingMemWrapperMultiD() = delete; + RotatingMemWrapperMultiD(Argument& arg_, + std::size_t rotating_count_, + std::size_t size_a_, + std::size_t size_b_, + std::array size_ds_) + : arg(arg_), + rotating_count(rotating_count_), + size_a(size_a_), + size_b(size_b_), + size_ds(size_ds_) + { + p_a_grids.push_back(arg.p_a_grid); + p_b_grids.push_back(arg.p_b_grid); + p_ds_grids.push_back(arg.p_ds_grid); + for(size_t i = 1; i < rotating_count; i++) + { + { + void* pADeviceBuf; + hip_check_error(hipMalloc(static_cast(&pADeviceBuf), size_a_)); + hip_check_error(hipMemcpy(static_cast(pADeviceBuf), + const_cast(p_a_grids[0]), + size_a_, + hipMemcpyDeviceToDevice)); + p_a_grids.push_back(pADeviceBuf); + } + + { + void* pBDeviceBuf; + hip_check_error(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); + hip_check_error(hipMemcpy(static_cast(pBDeviceBuf), + const_cast(p_b_grids[0]), + size_b_, + hipMemcpyDeviceToDevice)); + p_b_grids.push_back(pBDeviceBuf); + } + + { + + DsGridPointer ds_buffer; + static_for<0, NumDs, 1>{}([&](auto j) { + void* pDDeviceBuf; + hip_check_error(hipMalloc(static_cast(&pDDeviceBuf), size_ds_[j])); + hip_check_error(hipMemcpy(static_cast(pDDeviceBuf), + static_cast(p_ds_grids[0][j]), + size_ds_[j], + hipMemcpyDeviceToDevice)); + + using DDataType = remove_cvref_t>; + + ds_buffer(j) = static_cast(pDDeviceBuf); + }); + + p_ds_grids.push_back(ds_buffer); + } + } + } + + void Next() + { + if(rotating_count > 1) + { + std::size_t idx = iter++ % rotating_count; + arg.p_a_grid = 
reinterpret_cast(p_a_grids[idx]); + arg.p_b_grid = reinterpret_cast(p_b_grids[idx]); + arg.p_ds_grid = p_ds_grids[idx]; + } + } + void Print() + { + std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b + << ", rotating_count: " << rotating_count << "}" << std::endl; + } + ~RotatingMemWrapperMultiD() + { + if(rotating_count > 1) + { + // restore ptr + arg.p_a_grid = reinterpret_cast(p_a_grids[0]); + arg.p_b_grid = reinterpret_cast(p_b_grids[0]); + arg.p_ds_grid = p_ds_grids[0]; + + // free device mem + for(size_t i = 1; i < rotating_count; i++) + { + hip_check_error(hipFree(const_cast(p_a_grids[i]))); + hip_check_error(hipFree(const_cast(p_b_grids[i]))); + + static_for<0, NumDs, 1>{}([&](auto j) { + using DDataType = remove_cvref_t>; + hip_check_error( + hipFree(static_cast(const_cast(p_ds_grids[i][j])))); + }); + } + } + } + + private: + Argument& arg; + std::size_t iter = 0; + std::size_t rotating_count = 1; + std::size_t size_a = 0; + std::size_t size_b = 0; + std::array size_ds = {0}; + std::vector p_a_grids; + std::vector p_b_grids; + std::vector p_ds_grids; +}; + +template +struct RotatingMemWrapper +{ + using ADataType = decltype(Argument::p_a_grid); + using BDataType = decltype(Argument::p_b_grid); + + RotatingMemWrapper() = delete; + RotatingMemWrapper(Argument& arg_, + std::size_t rotating_count_, + std::size_t size_a_, + std::size_t size_b_) + : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_) + { + p_a_grids.push_back(arg.p_a_grid); + p_b_grids.push_back(arg.p_b_grid); + for(size_t i = 1; i < rotating_count; i++) + { + { + void* pADeviceBuf; + hip_check_error(hipMalloc(static_cast(&pADeviceBuf), size_a_)); + hip_check_error(hipMemcpy(static_cast(pADeviceBuf), + const_cast(p_a_grids[0]), + size_a_, + hipMemcpyDeviceToDevice)); + p_a_grids.push_back(pADeviceBuf); + } + + { + void* pBDeviceBuf; + hip_check_error(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); + hip_check_error(hipMemcpy(static_cast(pBDeviceBuf), + const_cast(p_b_grids[0]), + size_b_, + hipMemcpyDeviceToDevice)); + p_b_grids.push_back(pBDeviceBuf); + } + } + } + + void Next() + { + if(rotating_count > 1) + { + std::size_t idx = iter++ % rotating_count; + arg.p_a_grid = reinterpret_cast(p_a_grids[idx]); + arg.p_b_grid = reinterpret_cast(p_b_grids[idx]); + } + } + void Print() + { + std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b + << ", rotating_count: " << rotating_count << "}" << std::endl; + } + ~RotatingMemWrapper() + { + if(rotating_count > 1) + { + // restore ptr + arg.p_a_grid = reinterpret_cast(p_a_grids[0]); + arg.p_b_grid = reinterpret_cast(p_b_grids[0]); + + // free device mem + for(size_t i = 1; i < rotating_count; i++) + { + hip_check_error(hipFree(const_cast(p_a_grids[i]))); + hip_check_error(hipFree(const_cast(p_b_grids[i]))); + } + } + } + + private: + Argument& arg; + std::size_t iter = 0; + std::size_t rotating_count = 1; + std::size_t size_a = 0; + std::size_t size_b = 0; + std::vector p_a_grids; + std::vector p_b_grids; +}; + +inline void flush_icache() +{ + hipDeviceProp_t deviceProps; + hip_check_error(hipGetDeviceProperties(&deviceProps, 0)); + int32_t gpu_block3 = deviceProps.multiProcessorCount * 60; + + ck::flush_icache<<>>(); + hip_check_error(hipGetLastError()); +} +// if TimePrePress == false, return time does not include preprocess's time +template +float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, + PreProcessFunc preprocess, + F kernel, + dim3 grid_dim, + dim3 
block_dim, + std::size_t lds_byte, + GemmArgs& gemm_args, + Args... args) +{ +#if CK_TIME_KERNEL +#define MEDIAN 0 + if(stream_config.time_kernel_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } + // warm up + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); + } + + const int nrepeat = stream_config.nrepeat_; + if(nrepeat == 0) + { + return 0.0; + } + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } + +#if MEDIAN + std::set times; +#else + float total_time = 0; +#endif + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + if constexpr(!TimePreprocess) + { + preprocess(); + } + + // hipEvent_t start, stop; + + // hip_check_error(hipEventCreate(&start)); + // hip_check_error(hipEventCreate(&stop)); + + // hip_check_error(hipDeviceSynchronize()); + // hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + // calculate preprocess time + if constexpr(TimePreprocess) + { + preprocess(); + } + // run real kernel + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); + // end real kernel + + // hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + // hip_check_error(hipEventSynchronize(stop)); + // float cur_time = 0; + // hip_check_error(hipEventElapsedTime(&cur_time, start, stop)); + // #if MEDIAN + // times.insert(cur_time); + // #else + // total_time += cur_time; + // #endif + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; + + printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n", + static_cast(gemm_args.p_a_grid), + static_cast(gemm_args.p_b_grid)); + } + } + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + float cur_time = 0; + hip_check_error(hipEventElapsedTime(&cur_time, start, stop)); +#if MEDIAN + times.insert(cur_time); +#else + total_time += cur_time; +#endif + +#if MEDIAN + auto mid = times.begin(); + std::advance(mid, (nrepeat - 1) / 2); + if(nrepeat % 2 == 1) + { + return *mid; + } + else + { + auto mid_next = mid; + std::advance(mid_next, 1); + return (*mid + *mid_next) / 2; + } +#else + // return total_time / nrepeat; + hipDeviceProp_t deviceProps; + hip_check_error(hipGetDeviceProperties(&deviceProps, 0)); + float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 
0.005 : 0.01; + return (total_time - preprocess_offset * nrepeat) / nrepeat; +#endif + } + else + { + preprocess(); + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); + + return 0; + } +#else + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); + + return 0; +#endif +} + +} // namespace utility +} // namespace ck diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index af7bebd9d6afbf59f1852156b1708ed843d9a9f6..c0894f1d70e439e51e190905f79d3dcd31637c58 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -3,15 +3,32 @@ #pragma once +#include #include +// To be removed, which really does not tell the location of failed HIP functional call inline void hip_check_error(hipError_t x) { if(x != hipSuccess) { std::ostringstream ss; - ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ - << "in function: " << __func__; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " + << "hip_check_error.hpp" + << ": " << __LINE__ << "in function: " << __func__; throw std::runtime_error(ss.str()); } } + +#define HIP_CHECK_ERROR(retval_or_funcall) \ + do \ + { \ + hipError_t _tmpVal = retval_or_funcall; \ + if(_tmpVal != hipSuccess) \ + { \ + std::ostringstream ostr; \ + ostr << "HIP Function Failed (" \ + << "hip_check_error.hpp" \ + << "," << __LINE__ << ") " << hipGetErrorString(_tmpVal); \ + throw std::runtime_error(ostr.str()); \ + } \ + } while(0) diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index c3336be6c84c6c0d591665b5310365d39fdd9f7f..5c1c1c4e6049214a60095471b52e7af21d0a07a1 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -20,25 +20,31 @@ float launch_and_time_kernel(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { -#if DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up 1 time\n"); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } // warm up - kernel<<>>(args...); + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } - const int nrepeat = 10; -#if DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif + const int nrepeat = stream_config.nrepeat_; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } hipEvent_t start, stop; hip_check_error(hipEventCreate(&start)); @@ -50,6 +56,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, for(int i = 0; i < nrepeat; ++i) { kernel<<>>(args...); + hip_check_error(hipGetLastError()); } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); @@ -59,16 +66,21 @@ float launch_and_time_kernel(const StreamConfig& stream_config, hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + hip_check_error(hipEventDestroy(start)); + hip_check_error(hipEventDestroy(stop)); + return total_time / nrepeat; } else { kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; } #else 
kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; #endif @@ -86,26 +98,32 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { -#if DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up 1 time\n"); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } // warm up preprocess(); - kernel<<>>(args...); + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } - const int nrepeat = 10; -#if DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif + const int nrepeat = stream_config.nrepeat_; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } hipEvent_t start, stop; hip_check_error(hipEventCreate(&start)); @@ -118,6 +136,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { preprocess(); kernel<<>>(args...); + hip_check_error(hipGetLastError()); } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); @@ -127,17 +146,22 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + hip_check_error(hipEventDestroy(start)); + hip_check_error(hipEventDestroy(stop)); + return total_time / nrepeat; } else { preprocess(); kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; } #else kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; #endif diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index 505a602b240428bcd1f0f81018fef0c4716c80b7..37ba250cf5d58bbfe3b4a9e50fadfd9a4fa36c5a 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -11,4 +11,9 @@ struct StreamConfig hipStream_t stream_id_ = nullptr; bool time_kernel_ = false; int log_level_ = 0; + int cold_niters_ = 5; + int nrepeat_ = 50; + + bool flush_cache = false; + int rotating_count = 1; }; diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index ae3139ce78c8b3b881ee36602011dd895532bd83..c152cbfb1ea1eed1a77e1172b4dd3c398ebca98b 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
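// Illustrative, self-contained sketch (not part of this patch) of the timing scheme that the
// kernel_launch.hpp / StreamConfig changes above implement: launch the kernel cold_niters_
// times as warm-up, then bracket nrepeat_ launches with one hipEvent pair and report the
// average per-launch time in milliseconds (hipEventElapsedTime returns ms). "dummy_kernel"
// is a stand-in for the caller-supplied kernel, and error checking is omitted for brevity;
// the real code checks every HIP call and records on stream_config.stream_id_.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {}

int main()
{
    const int cold_niters = 5;  // StreamConfig::cold_niters_ default
    const int nrepeat     = 50; // StreamConfig::nrepeat_ default

    for(int i = 0; i < cold_niters; ++i)
        dummy_kernel<<<1, 64>>>(); // warm-up launches are not timed

    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);
    hipDeviceSynchronize();
    hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
        dummy_kernel<<<1, 64>>>(); // timed launches
    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);

    float total_ms = 0;
    hipEventElapsedTime(&total_ms, start, stop);
    hipEventDestroy(start);
    hipEventDestroy(stop);
    printf("average kernel time: %f ms\n", total_ms / nrepeat);
    return 0;
}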
#pragma once @@ -1951,4 +1951,96 @@ struct Modulo printf("}"); } }; + +template +struct Xor +{ + using LowerIndex = MultiIndex<2>; + using UpperIndex = MultiIndex<2>; + + using UpLengths = LowLengths; + + UpLengths up_lengths_; + + __host__ __device__ constexpr Xor() : up_lengths_{} {} + + __host__ __device__ constexpr Xor(const LowLengths& low_lengths) : up_lengths_{low_lengths} {} + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 2; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 2; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 2 && UpIdx::Size() == 2, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + + if constexpr(ApplyModulo) + { + idx_low(Number<1>{}) = + idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]); + } + else + { + idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}]; + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up, + Number) const + { + static_assert(LowIdxDiff::Size() == 2 && UpIdxDiff::Size() == 2 && LowIdx::Size() == 2 && + UpIdx::Size() == 2, + "wrong! inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + + CalculateLowerIndex(idx_low, idx_up); + + idx_diff_low = idx_low - idx_low_old; + } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Xor{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + printf("}"); + } +}; } // namespace ck diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index af0a8a34d0e48bde494dacfff5bbb17eda5eb479..8feadf63c604b562c8fb27bef02525d1c4e1929a 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
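// Illustrative sketch (not part of this patch): the index mapping performed by the Xor
// transform added above, written as a plain host function. Lower index 0 passes through;
// lower index 1 is upper index 1 XOR-ed with upper index 0, optionally reduced modulo the
// length of dimension 1 (the ApplyModulo case used by make_xor_with_modulo_transform).
// Because x ^ k ^ k == x, the mapping is its own inverse, which is what makes XOR-based
// swizzles (for example of LDS layouts) cheap to undo.
#include <array>
#include <cassert>
#include <cstdint>

std::array<int64_t, 2>
xor_lower_index(const std::array<int64_t, 2>& idx_up, int64_t len1, bool apply_modulo)
{
    const int64_t key = apply_modulo ? idx_up[0] % len1 : idx_up[0];
    return {idx_up[0], idx_up[1] ^ key};
}

int main()
{
    const std::array<int64_t, 2> up{5, 9};
    const auto low = xor_lower_index(up, /*len1=*/8, /*apply_modulo=*/true);
    // key = 5 % 8 = 5, so low = {5, 9 ^ 5} = {5, 12}
    assert(low[0] == 5 && low[1] == 12);
    // applying the same transform to the result recovers the original upper index
    const auto roundtrip = xor_lower_index(low, 8, true);
    assert(roundtrip == up);
    return 0;
}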
#pragma once @@ -127,4 +127,16 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus, { return Modulo{modulus, up_length}; } + +template +__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths) +{ + return Xor{low_lengths}; +} + +template +__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths) +{ + return Xor{low_lengths}; +} } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp deleted file mode 100644 index e527509f57af60ef9065ccdbf8c2112d7ed336af..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp +++ /dev/null @@ -1,370 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/amd_gemm_dpp.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp" - -namespace ck { - -/** - * DPP8 version of blockwise GEMM algorithm. It uses DPP8 instruction modifier to limit - * the data loaded from LDS to registers. - * - * The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C - * between them in such a way that threads from the same group need the same chunk of either - * matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the - * whole chunk from LDS to its own register space. - * Usage of DPP8 modifiers allow each thread to load less data, exactly `1 / dpp8::lane_group_size` - * of the chunk, and then share that data with other threads from the same lane group. - * - * Assumptions coming from the usage of DPP8: - * 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or - * `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` - - * - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either - * matrix A or B; - * - based on these values we determine which matrix to share. - * 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or - * `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) - - * - we have to make sure that the data to split is divisible by the number of - * threads in the group. - * - * General algorithm: - * C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] - * A and B are visible to the whole block, C is distributed among each thread - * Assume: - * 1. A: - * 1. ABlockDesc_BK0_BM_BK1 is known at compile-time - * 2. ABlockBuffer is DynamicBuffer - * 2. B: - * 1. BBlockDesc_BK0_BN_BK1 is known at compile-time - * 2. BBlockBuffer is DynamicBuffer - * 3. C: - * 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time - * 2. CThreadBuffer is StaticBuffer - * 4. 
BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 - */ -template - typename BM10BN10ThreadClusterBN10Xs, // Sequence - index_t AThreadCopyScalarPerVector_BM11, - index_t BThreadCopyScalarPerVector_BN11, - typename enable_if::type = false> -struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0 -{ - using AIndex = MultiIndex<4>; - using BIndex = MultiIndex<4>; - using CIndex = MultiIndex<4>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); - static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); - static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); - static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); - - static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; - static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; - - static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; - static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; - - static constexpr index_t BM11 = BM1PerThreadBM11; - static constexpr index_t BN11 = BN1PerThreadBN11; - - static constexpr index_t BM1 = BM100 * BM101 * BM11; - static constexpr index_t BN1 = BN100 * BN101 * BN11; - - static constexpr index_t BM0 = BM / BM1; - static constexpr index_t BN0 = BN / BN1; - - // We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all - // threads in a lane group need the same chunk of B or A matrices and we can share them using - // DPP. - static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size); - static constexpr bool ShareB = BM101 == dpp8::lane_group_size ? true : false; - static constexpr bool ShareA = !ShareB; - - // If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`, - // respectively) elements, so we split them between threads in lane group so each thread loads - // less data from LDS. - static constexpr index_t BM1PerThread = - ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11; - static constexpr index_t BN1PerThread = - ShareB ? 
BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11; - - __host__ __device__ static constexpr auto - MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) - { - const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor( - a_block_desc_bk0_bm_bk1, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return a_block_bk0_bm0_bm1_bk1; - } - - __host__ __device__ static constexpr auto - MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) - { - const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor( - b_block_desc_bk0_bn_bk1, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return b_block_desc_bk0_bn0_bn1_bk1; - } - - __host__ __device__ static constexpr auto - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() - { - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // lower: [BM, BN] - constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); - - return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; - } - - __host__ __device__ static constexpr auto - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() - { - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // lower: [BM0, BM1, BN0, BN1] - constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = - make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); - - return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; - } - - __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() - { - return Sequence{}; - } - - static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = - MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); - - static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = - MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); - - public: - __device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0() - : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( - get_thread_local_1d_id())}, - a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()}, - b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()} - { - static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && - BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), - "wrong! 
Desc should be known at compile-time"); - - static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); - - static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == - BBlockDesc_BK0_BN_BK1{}.GetLength(I0), - "wrong! K dimension not consistent"); - - static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && - BM10BN10ThreadClusterBN10Xs::Size() == 2, - "wrong!"); - } - - __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) - { - // lower: [BM0, BM1, BN0, BN1] - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - constexpr auto adaptor0 = - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); - - // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // upper: [Tid, BM0, BM11, BN0, BN11] - constexpr auto adaptor1 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), - make_pass_through_transform(BM0), - make_pass_through_transform(BM11), - make_pass_through_transform(BN0), - make_pass_through_transform(BN11)), - make_tuple( - Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); - - return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); - } - - __device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1() - { - const auto offsetBM0 = c_thread_origin_data_idx_[I0]; - // If sharing matrix A, we need a separate BM1 offset for each thread in lane group. - const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] + - dpp8::get_thread_idx_in_lane_group() * BM1PerThread - : c_thread_origin_data_idx_[I1]; - return make_tuple(0, offsetBM0, offsetBM1, 0); - } - - __device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1() - { - const auto offsetBN0 = c_thread_origin_data_idx_[I2]; - // If sharing matrix B, we need a separate BN1 offset for each thread in lane group. - const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] + - dpp8::get_thread_idx_in_lane_group() * BN1PerThread - : c_thread_origin_data_idx_[I3]; - return make_tuple(0, offsetBN0, offsetBN1, 0); - } - - template - __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, - const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), - "wrong! 
Desc should be known at compile-time"); - - auto a_thread_buf = make_static_buffer( - a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); - - constexpr auto threadwise_contraction = - ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< - FloatA, - FloatB, - FloatC, - decltype(a_thread_desc_bk0_bm0_bm1_bk1_), - decltype(b_thread_desc_bk0_bn0_bn1_bk1_), - CThreadDesc_BM0_BM11_BN0_BN11, - Sequence, - Sequence<1, BM1PerThreadBM11>, - Sequence<1, BN1PerThreadBN11>, - ShareA>{}; - - static_for<0, BN0, 1>{}([&](auto bn0) { - static_for<0, BM0, 1>{}([&](auto bm0) { - a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, bm0, I0, I0), - a_block_buf, - a_thread_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, bn0, I0, I0), - b_block_buf, - b_thread_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - threadwise_contraction.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(bm0, I0, bn0, I0)); - - static_for{}([&](auto bk0) { - a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, - make_tuple(bk0, bm0, I0, I0), - a_block_buf, - a_thread_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, - make_tuple(bk0, bn0, I0, I0), - b_block_buf, - b_thread_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - threadwise_contraction.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(bm0, I0, bn0, I0)); - }); - }); - }); - } - - private: - // A[BK0, BM0, BM1, BK1] - static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{}, Number{})); - - // B[BK0, BN0, BN1, BK1] - static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< - FloatA, - FloatA, - decltype(a_block_desc_bk0_bm0_bm1_bk1_), - decltype(a_thread_desc_bk0_bm0_bm1_bk1_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, BM1PerThread, BK1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< - FloatB, - FloatB, - decltype(b_block_desc_bk0_bn0_bn1_bk1_), - decltype(b_thread_desc_bk0_bn0_bn1_bk1_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, BN1PerThread, BK1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder - - CIndex c_thread_origin_data_idx_; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f03427a7ead1cfef5450f04dfa0dda7b38e2763f --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
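// Illustrative sketch (not part of this patch): the data-partitioning arithmetic behind the
// DPP lanegroup sharing used by the removed DPP8 blockwise GEMM above and the new
// blockwise_gemm_dpp.hpp below. When matrix A is the shared operand, each thread of a
// lanegroup loads only 1 / lane_group_size of the per-thread M tile from LDS, offset by its
// position in the lanegroup, and the rest is exchanged between lanes via DPP. The values
// below (lane_group_size = 8, a 16-element per-thread tile) are example parameters, not a
// specific kernel instance.
#include <cstdio>

int main()
{
    constexpr int lane_group_size   = 8;  // DPP8 exchanges data within groups of 8 lanes
    constexpr int m1_per_thread_tot = 16; // stands in for BM1PerThreadBM11
    static_assert(m1_per_thread_tot % lane_group_size == 0,
                  "per-thread tile must be divisible by the lanegroup size");

    constexpr int m1_per_thread = m1_per_thread_tot / lane_group_size; // what each lane loads

    for(int lane = 0; lane < lane_group_size; ++lane)
    {
        // mirrors the extra BM1 offset added to the thread origin when sharing A
        const int offset = lane * m1_per_thread;
        printf("lane %d loads M1 rows [%d, %d)\n", lane, offset, offset + m1_per_thread);
    }
    return 0;
}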
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp" + +namespace ck { + +/** + * Blockwise GEMM that uses DPP instruction modifier to limit the amount of data loaded for each + * thread by sharing the data between threads in a lanegroup. + * + * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`, there are + * `MRepeat` iterations for `M` dimension and `NRepeat` for `N` one. + * In total, the algorithm runs using + * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves. + */ +template +struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto dpp_gemm = DppGemm{}; + + static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M(); + const auto dpp_a_idx_k = dpp_a_idx[I0]; + const auto dpp_a_idx_m = dpp_a_idx[I1]; + return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k); + } + + __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N(); + const auto dpp_b_idx_k = dpp_b_idx[I0]; + const auto dpp_b_idx_n = dpp_b_idx[I1]; + return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk(); + const auto blk_m_offset = blk_idx[I0]; + const auto blk_n_offset = blk_idx[I1]; + + constexpr auto 
mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_m_offset))[I0]; + const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_n_offset))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "Wrong! Block descriptors should be known at the time of compilation."); + +#if defined(__HIP_DEVICE_COMPILE__) + // Host wave size can be different than the device one and this assert could fail for host, + // but it does matter only for device. + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); +#endif + + static_assert(MPerBlock % (MPerDpp * MRepeat) == 0, + "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat."); + static_assert(NPerBlock % (NPerDpp * NRepeat) == 0, + "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat."); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return c_block_desc_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + return c_block_desc_g_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return 
c_grid_desc_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return c_grid_desc_g_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using dpp_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + dpp_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegDpp] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, 
Number{}, dpp_gemm.GetRegSizePerDpp())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp new file mode 100644 index 0000000000000000000000000000000000000000..438d7d8ac3c10d871c45e203bf98311d4647bc49 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -0,0 +1,994 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +// Double LDS buffer +// Prefetech 2 stage +// Local prefetch 1 stage + +namespace ck { + +template +struct BlockwiseGemmXdlops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 64; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t C_MFMA_Inst_Num = + MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + KPerXDL); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C MFMA inst: %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_MFMA_Inst_Num); + } +}; + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename ATileDesc, + typename BTileDesc, + typename AMmaTileDesc, + typename BMmaTileDesc, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerXDL, + index_t NPerXDL, + index_t MRepeat, + index_t NRepeat, + index_t KPack, + bool TransposeC = false, + index_t AMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> +struct BlockwiseGemmXdlops_pipeline_v4 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 
= Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = 
decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ + BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + + // HotLoopInstList::Print(); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template 
+ __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __device__ static constexpr auto HotLoopScheduler() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_buffer_load_inst = + HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_buffer_load_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA + }); + } + + template + __device__ static constexpr auto TailScheduler() + { + } + + template <> + __device__ constexpr auto TailScheduler<1>() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_ds_write_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA + }); + } + + template <> + __device__ constexpr auto TailScheduler<2>() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_ds_read_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA + }); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + // Inst List: + // ds_read_b128: 16 + // ds_write_b128: 8 + // buffer_load_dwordx4: 16 + // v_mfma: 0 + // ------------------------------------------------------------------------------------------- + + // Global prefetch 1th, Fill Ping LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); + + // Local prefetch 1th, Fill Ping Reg + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + }); + + // Global prefetch 2th, Fill Pong LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1)); + + // Global prefetch 3rd + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + // ------------------------------------------------------------------------------------------- + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + 
vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 2; + } while(i < (num_loop - 3)); + } + + // tail + if constexpr(TailNum == 3) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<1>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + 
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PongP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PongP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == 2) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + 
__builtin_amdgcn_sched_barrier(0); + } + } + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1ab460fa8aadf0896e17f747327d17f9cb92032e --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp" + +namespace ck { + +enum struct BlockGemmPipelineVersion +{ + v1, // Naive + v2, // Mem + v3, // Comp +}; + +template +constexpr auto BlockGemmABScalePipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1_ab_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_v2_ab_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3_ab_scale{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp new file mode 100644 index 0000000000000000000000000000000000000000..036a40488a7ac13e78c81891bdd1c2d6bf23b8fc --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
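The base class introduced in this new header derives each thread's wave coordinates by unflattening the linear thread id over (MWaves, NWaves, WaveSize), with the wave size hard-coded to 64 so that RDNA's 32-wide warpSize does not leak in. A minimal standalone sketch of that decomposition, using assumed wave counts (MWaves = 2, NWaves = 2) that are illustrative only:

struct WaveIdxSketch
{
    int wave_m; // wave index along M
    int wave_n; // wave index along N
    int lane;   // lane within the 64-wide wave
};

// Hypothetical helper, not part of the library: same ordering as the merge transform in
// GetWaveIdx(), i.e. the lane varies fastest, then the N-wave index, then the M-wave index.
constexpr WaveIdxSketch decompose_thread_id_sketch(int thread_id)
{
    constexpr int NWaves = 2, WaveSize = 64; // assumed values for illustration
    return {thread_id / (NWaves * WaveSize), (thread_id / WaveSize) % NWaves, thread_id % WaveSize};
}

static_assert(decompose_thread_id_sketch(70).wave_m == 0 &&
                  decompose_thread_id_sketch(70).wave_n == 1 &&
                  decompose_thread_id_sketch(70).lane == 6,
              "thread 70 sits in wave (0, 1), lane 6");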
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_pipeline_base +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs. + static constexpr index_t WaveSize = 64; + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t AMmaKStride = KPack; + static constexpr index_t BMmaKStride = KPack; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = + ck::BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = 
mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ + BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto 
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..2558b58c2ce857851b67458dd13c61ce04556f86 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp" + +namespace ck { + +enum struct BlockGemmPipelineVersion +{ + v1, // Naive + v2, // Mem + v3, // Comp + v4, // Comp, double lds buffer + v5, // Comp, double global prefetch register buffer +}; + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_v2{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return BlockwiseGemmXdlops_pipeline_v4{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5) + { + return BlockwiseGemmXdlops_pipeline_v5{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f597573dc2a46660ba6f10f8a61050a502f0e2ad --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -0,0 +1,731 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
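The BlockGemmPipeline_Selector() shown above is a compile-time factory: each if constexpr branch returns a value of a different pipeline type, so the choice is resolved entirely at compile time and callers typically capture the selected type via decltype. A simplified, hypothetical sketch of that idiom, with placeholder pipeline types standing in for the heavily templated real ones:

#include <type_traits>

enum struct VersionSketch { v1, v2, v3 };

struct NaivePipelineSketch     { static constexpr int prefetch_stages = 1; };
struct MemBoundPipelineSketch  { static constexpr int prefetch_stages = 2; };
struct CompBoundPipelineSketch { static constexpr int prefetch_stages = 2; };

// Hypothetical stand-in for BlockGemmPipeline_Selector(): the branch taken decides the type
// of the returned object.
template <VersionSketch Ver>
constexpr auto select_pipeline_sketch()
{
    if constexpr(Ver == VersionSketch::v1)      { return NaivePipelineSketch{}; }
    else if constexpr(Ver == VersionSketch::v2) { return MemBoundPipelineSketch{}; }
    else                                        { return CompBoundPipelineSketch{}; }
}

// How a caller captures the selected pipeline type.
using SelectedSketch = std::remove_cv_t<decltype(select_pipeline_sketch<VersionSketch::v2>())>;
static_assert(std::is_same_v<SelectedSketch, MemBoundPipelineSketch>,
              "v2 maps to the memory-bound placeholder");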
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + 
static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + 
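The two constants defined next split each thread's K work into MAC clusters for inter-wave scheduling: KPerInnerLoop is the K chunk one cluster consumes, and KRepeat is the number of clusters per loop iteration. As a worked sketch with assumed values (KPerThread = 16, KPack = 4, two MAC clusters), none of which come from this change:

#include <algorithm>

// Assumed example values, for illustration only.
constexpr int k_per_thread     = 16;
constexpr int k_pack           = 4;
constexpr int num_mac_clusters = 2; // stands in for CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS

constexpr int k_per_inner_loop = std::max(k_per_thread / num_mac_clusters, k_pack); // = 8
constexpr int k_repeat         = k_per_thread / k_per_inner_loop;                   // = 2

static_assert(k_per_inner_loop == 8 && k_repeat == 2,
              "each loop iteration runs two MAC clusters of eight K elements per thread");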
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, + // but except the first, as we can shorten non-MAC cluster a bit and there's no + // observable negative impact. The desired effect is waves in a workgroup + // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC + // resource from other workgroups and reducing the chance of latency hiding by + // waiting for the rest of the workgroup at the eventual sync point. 
+ if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion by + // applying small delays to different wavefronts It is performed + // near the end of MAC cluster to minimize lgkmcnt penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); 
+ __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + } + + protected: + // K->M loopover + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..821bbb0051916a378ac762dbbb61a930b21abb23 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1_ab_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1_ab_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + 
const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop, + index_t num_loop_per_scale) const + { + // assume kperblock = scaleblockk + ignore = num_loop_per_scale; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto c_thread_buf_per_scale = remove_cvref_t(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename 
vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[I0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[I0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..711c47854adad7b2880718f69ec3febe05984bb4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -0,0 +1,1156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
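The ab-scale pipeline that closes just above accumulates the unscaled partial products of one K block into c_thread_buf_per_scale and only then folds them into c_thread_buf together with the A and B scale factors. A scalar sketch of that accumulation order, illustrative only and using plain pointers rather than the kernel's threadwise buffers:

#include <cstddef>

// Sketch: c += (sum_k a[k] * b[k]) * a_scale * b_scale, mirroring how c_thread_buf_per_scale
// is cleared per (m0, n0) tile, filled by the MFMA loop over K, then scaled exactly once.
float scaled_accumulate_sketch(const float* a, const float* b, std::size_t k_len,
                               float a_scale, float b_scale, float c)
{
    float c_per_scale = 0.f; // analogue of c_thread_buf_per_scale.Clear()
    for(std::size_t k = 0; k < k_len; ++k)
    {
        c_per_scale += a[k] * b[k]; // analogue of xdlops_gemm.Run accumulating unscaled products
    }
    return c + c_per_scale * a_scale * b_scale; // analogue of the c_thread_buf update
}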
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Maximum Global Memory throughput pipeline with >=32KB data in fly +// GlobalPrefetchStages: >=2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v2 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t WgpPerCU = + (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, 
a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + 
b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: 
+ using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + + static constexpr index_t WgpPerCU = + (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + 
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC + // cluster, but except the first, as we can shorten non-MAC cluster a bit + // and there's no observable negative impact. The desired effect is waves in + // a workgroup executing MAC in sync. This avoids some out-of-sync waves + // hijacking MAC resource from other workgroups and reducing the chance of + // latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. 
+ if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value 
== KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + 
vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + // K->M loopover + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..40fa776484342fcc3381a5f42ae3a1bd221a90c7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -0,0 +1,633 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
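+
+// Illustrative note (comment only; the tile sizes below are assumptions, not defaults):
+// this pipeline sizes its global prefetch depth so that roughly 32 KiB of A/B tile data
+// per CU stays in flight, then clamps the result to [2, 8] (see FullMemBandPrefetchStages
+// and PrefetchStages below). For example, with warpSize = 64, BlockSize = 256,
+// MPerBlock = NPerBlock = 128, KPerBlock = 32 and 2-byte A/B element types:
+//
+//   int wgp_per_cu     = (4 * 64 / 256 >= 1) ? 4 * 64 / 256 : 1;                      // 1
+//   int bytes_per_iter = (128 * 2 + 128 * 2) * 32;                                    // 16384
+//   int full_stages    = (32768 / wgp_per_cu + bytes_per_iter - 1) / bytes_per_iter;  // 2
+//   int prefetch       = full_stages >= 2 ? (full_stages <= 8 ? full_stages : 8) : 2; // 2
+//
+// The remainder num_loop % PrefetchStages then selects the TailNumber branch handled by
+// BlockLoopTailNum() and the tail section of Run().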
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Maximum Global Memory throughput pipeline with >=32KB data in fly +// GlobalPrefetchStages: >=2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v2_ab_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2_ab_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t WgpPerCU = + (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& 
b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop, + index_t num_loop_per_scale) const + { + // assume kperblock = scaleblockk + ignore = num_loop_per_scale; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Initialize C + c_thread_buf.Clear(); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + auto c_thread_buf_per_scale = remove_cvref_t(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[I0]) * + 
type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[I0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 
1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[I0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[I0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..171a232c0f9b2c61cfeabfed62ba1ba199ae0d1b --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -0,0 +1,463 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + 
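+        // Illustrative sketch (comment only): HotLoopScheduler() above steers the compiler with
+        // __builtin_amdgcn_sched_group_barrier(mask, count, group) hints; the masks used in this
+        // file are 0x008 (MFMA), 0x020 (VMEM read), 0x100 (DS read) and 0x200 (DS write), while
+        // __builtin_amdgcn_sched_barrier(0) fences a whole pipeline stage. The ds_read-per-MFMA
+        // pairing rate is a plain ceiling division; for an assumed 32-cycle MFMA and an 8-cycle
+        // 128-bit ds_read issue:
+        //
+        //   int mfma_cycle = 32, issue_cycle = 8;
+        //   int ds_read_per_mfma = (mfma_cycle - 4 + 2 * issue_cycle - 1) / (2 * issue_cycle); // 2
+        //
+        // i.e. roughly two ds_read issues are grouped with each MFMA slot in stage 2 of the schedule.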
__builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + 
a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..de542866a635395fea0722c7f9d62a78dffa6630 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -0,0 +1,534 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3_ab_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3_ab_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? 
HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = 4; + // HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = 4; + // HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + 
else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + // assume kperblock = scaleblockk + ignore = num_loop_per_scale; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + auto c_thread_buf_per_scale = remove_cvref_t(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + 
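+        // Illustrative sketch (comment only, scalar pseudo-accumulation under the stated
+        // "kperblock == scaleblockk" assumption): in the main body below every (m0, n0) tile is
+        // first accumulated unscaled into c_thread_buf_per_scale over the K repeats, and only the
+        // finished partial sum is folded into the final accumulator with the per-block scales:
+        //
+        //   // for each accumulator register t of the current (m0, n0) tile
+        //   acc[c_offset + t] += partial[t] * float(a_scale) * float(b_scale);
+        //
+        // where a_scale / b_scale are the single scalars fetched by a_scale_thread_copy /
+        // b_scale_thread_copy above, refreshed once per K block (accumulator type assumed float).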
+ __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[I0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[I0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: 
+ using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bd5a1bedf537c3a4a31d53cae5f2d5ca1beeabb9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -0,0 +1,594 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimimal pipeline with highest resource request +// GlobalPrefetchStages: 4 +// LocalPreFillStages: 2 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_v4 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v4 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 4; + static constexpr index_t PrefillStages = 2; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopUnroll = 2; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + template + __device__ static constexpr void HotLoopScheduler(ScheduleGroup schedule_group) + { + // TODO: Take data type into consideration as pipe ver 3 + // A-B splited schedule + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? 
HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_a = + (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a; + constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a; + + constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_b = + (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b; + constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b; + + constexpr auto num_mfma_per_issue = + HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b); + + static_for<0, num_issue_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + }); + + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_a, + schedule_group); // MFMA + }); + + static_for<0, num_issue_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + }); + + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_b, + schedule_group); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + 
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0), I0); + + // Local prefill 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1), I1); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + }); + + // Global prefetch 3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Global prefetch 4 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + auto LoopFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto vmem_buf, + auto mfma_reg_buf, + auto schedule_group) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + }); + + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), 
+ b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(schedule_group); + }; + + LoopFunc(I1, I1, I0, I0, I0, I0); + LoopFunc(I0, I0, I1, I1, I1, I0); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto vmem_buf, + auto mfma_reg_buf, + auto schedule_group) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(schedule_group); + }; + + auto ReadCompFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto mfma_reg_buf, + auto schedule_group) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(schedule_group); + }; + + auto CompFunc = [&](auto mfma_reg_buf) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { 
+ vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + // tail + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I1, I1, I0, I0, I0, I1); + ReadCompFunc(I0, I0, I1, I1); + CompFunc(I0); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadWriteCompFunc(I1, I1, I0, I0, I0, I1); + ReadWriteCompFunc(I0, I0, I1, I1, I1, I1); + ReadCompFunc(I1, I1, I0, I1); + CompFunc(I1); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b6a4f05502ad03d28a98ca6d351ee613e0c4f56c --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -0,0 +1,664 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 3 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_v5 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v5 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopUnroll = 2; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + __device__ static constexpr auto HotLoopScheduler() 
+ { + // TODO: Take data type into consideration as pipe ver 3 + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat; + constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat; + + constexpr auto num_dsread_stage1_a_mfma = + (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage1_b_mfma = + (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + constexpr auto num_dsread_stage3_a_mfma = + (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage3_b_mfma = + (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + constexpr auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate - + num_ds_read_inst_b / ds_read_b_mfma_rate; + constexpr auto num_mfma_per_issue = + num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + // stage 1 + static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * 
ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + // stage 2 + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 3 + static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + // IGLP COMPILER BUG: + // If comment out following scheduler barrier would cause sanity fail. 
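// A minimal worked instance of the interleave-rate arithmetic above, with hypothetical values
// (16-byte ds_read issues, so an 8-cycle issue slot, and NPerXDL == 32, so a 32-cycle MFMA):
// rate = ceil((mfma_cycle - 4) / (2 * ds_read_issue_cycle)) = ceil(28 / 16) = 2, i.e. at most
// two ds_read instructions are grouped with each MFMA slot.
constexpr index_t example_mfma_cycle    = 32; // hypothetical NPerXDL == 32
constexpr index_t example_ds_read_issue = 8;  // hypothetical 16-byte ds_read
constexpr index_t example_ds_read_rate =
    (example_mfma_cycle - 4 + 2 * example_ds_read_issue - 1) / (2 * example_ds_read_issue);
static_assert(example_ds_read_rate == 2, "illustrative ds_read-per-MFMA rate only");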
+ __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Global prefetch 3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto vmem_buf) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + } + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + 
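// Prefetch the next K slice, (k0 + 1) % KRepeat, from LDS into registers while the MFMAs
// issued above drain; the modulo wraps the final iteration back to slice 0 so it is ready
// for the next pass over KRepeat.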
a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I0); + LoopFunc(I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + // tail + auto ReadWriteCompFunc = [&](auto vmem_buf) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf); + + block_sync_lds(); + } + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + }; + auto ReadCompFunc = [&]() { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KRepeat - 1, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + 
b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + HotLoopScheduler(); + }; + + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I0); + ReadWriteCompFunc(I1); + ReadCompFunc(); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadWriteCompFunc(I0); + ReadCompFunc(); + } + } + + protected: + // A[MRepeat, I1, I1, KPack] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, Number{})); + + // B[NRepeat, N1, N2, KPack] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e9f9b0be7eab7fcb89435cf3e51b1b3b0c62abd0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp @@ -0,0 +1,453 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +__host__ __device__ static constexpr auto +MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) +{ + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); +} + +template +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + SparseXdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto 
mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple(Number{}, + Number{}, + waveId_m, + waveId_n, + blk_idx[I0], + blk_idx[I1], + blk_idx[I2], + blk_idx[I3]); + } + + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0, + "MPerBlock must be divisible by MPerXDL * MRepeat"); + static_assert(NPerBlock % (NPerXDL * NRepeat) == 0, + "NPerBlock must be divisible by NPerXDL * NRepeat"); + + static_assert( + KPack % (16 * sizeof(ComputeTypeA)) == 0, + "KPack must be divisbile by number of elements processed in single smfmac instruction"); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + 
Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + // Prepares data in a_thread_buf by squeezing values by ommiting zeros to adjust it to 2:4 + // structural sparsity. 
The indexes of non-zero elements are stored in idx_buf and used later in + // smfmac instruction + template + __device__ void SetIdxSqueezeA(AThreadBuf& a_thread_buf, IdxBuf& idx_buf) + { + static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000}; + static constexpr int32_t processed_elems = 16 / sizeof(ComputeTypeA); + + static_for<0, num_elems, processed_elems>{}([&](auto i) { + constexpr int idx_reg_num = i / (16 * sizeof(ComputeTypeA)); + constexpr int idx_reg_part = (i % 32) / processed_elems; + + vector_type a_thread_vec; + static_for<0, processed_elems, 1>{}([&](auto j) { + a_thread_vec.template AsType()(j) = a_thread_buf + [Number{}]; + }); + + uint8_t idx = 0b11101110; // set to last 2 elems for both 4-elems subgroups by default + for(int j = 0; j < processed_elems; j += 4) + { + int32_t a_pos = idx_reg_part * processed_elems + j; + int32_t nonzero_pos = 0; + ComputeTypeA nonzero_elems[2] = {a_thread_vec[j + 2], a_thread_vec[j + 3]}; + for(int k = 0; k < 3; k += 1) + { + if(a_thread_vec[j + k] != 0.0f) + { + nonzero_elems[nonzero_pos] = a_thread_vec[j + k]; + idx &= ~bit_clear_masks[j / 2 + nonzero_pos]; + idx |= k << 2 * (j / 2 + nonzero_pos); + ++nonzero_pos; + } + } + a_thread_vec[j / 2] = nonzero_elems[0]; + a_thread_vec[j / 2 + 1] = nonzero_elems[1]; + } + IdxBuf[idx_reg_num].AsType()[Number{}] = idx; + + static_for<0, processed_elems / 2, 1>{}([&](auto j) { + a_thread_buf[Number{}] = a_thread_vec[j]; + }); + }); + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + static constexpr int32_t elems_per_idx = 16 * sizeof(ComputeTypeA); + auto idx_buf = make_static_buffer( + (a_thread_desc_.GetElementSpaceSize() + elems_per_idx - 1) / elems_per_idx); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + SetIdxSqueezeA(a_thread_buf, idx_buf, a_thread_desc_.GetElementSpaceSize()); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + // a_thread_vec is smaller because it's structurally sparse 2:4 + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type idx_vec; + + static_for<0, KPack / 2, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack, 1>{}([&](auto i) { + b_thread_vec.template AsType()(2 * i) = b_thread_buf + [Number{}]; + }); + + static_for<0, KPack / elems_per_idx, 1>{}([&](auto i) { + idx_vec.template AsType()(i) = idx_buf[k / elems_per_idx + i]; + }); + + // A is smaller because it's structurally sparse 2:4 + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + using mfma_input_type_idx = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + idx_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + 
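// A minimal sketch of the 2:4 metadata encoding produced by SetIdxSqueezeA above, traced on a
// hypothetical 4-element group {0, x, 0, y}: the non-zeros sit at positions 1 and 3, the kept
// pair is {x, y}, and the 4-bit field becomes (3 << 2) | 1 = 0b1101 (low two bits store the
// first kept position, high two bits the second; 0b1110, the default, keeps the last two
// elements of a group).
constexpr uint8_t example_idx_nibble = (3 << 2) | 1; // keep positions {1, 3}
static_assert(example_idx_nibble == 0b1101, "illustrative 2:4 index nibble only");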
}); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 5ec964bd3fdbda9319d1c6eab0ccb207775a6137..fa389c3402d1052c317ea2139d7299a627208da6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,34 +7,56 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #define CK_MNK_LOOP namespace ck { +#ifdef __gfx12__ template -/* A: K0PerBlock x MPerBlock x K1 + index_t KPack, + bool AEnableLds = true, + bool BEnableLds = true, + bool TransposeC = false> +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 * B: K0PerBlock x NPerBlock x K1 - * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs */ -struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle +struct BlockwiseGemmWMMA { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; static constexpr auto WmmaK = Number<16>{}; using ThisThreadBlock = ThisThreadBlock; @@ -42,18 +64,17 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. 
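// A minimal sketch of the wave tiling used below, with hypothetical tile sizes (MPerBlock =
// NPerBlock = 128, MPerWMMA = NPerWMMA = 16, MRepeat = NRepeat = 4): MWaves = 128 / (4 * 16) = 2
// and NWaves = 2, so the block is expected to run 2 * 2 * 32 = 128 threads of wave32.
static_assert((128 / (4 * 16)) * (128 / (4 * 16)) * 32 == 128,
              "illustrative MWaves * NWaves * WaveSize thread count only");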
static constexpr index_t WaveSize = 32; - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t KPerBlock = - BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; - static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); - static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); - static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); static constexpr auto wmma_gemm = - WmmaGemm{}; + WmmaGemm{}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); @@ -79,26 +100,39 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); } + // Default, Block buffer in LDS, thread level offset enabled __device__ static auto CalculateAThreadOriginDataIndex() { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - // |KRepeat |MRepeat|MWave |MLane |KPack - return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); + if constexpr(AEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } } __device__ static auto CalculateBThreadOriginDataIndex() { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_n = wave_idx[I1]; - - const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); - // |KRepeat |NRepeat|Nwave |NLane |KPack - return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } } template @@ -129,10 +163,26 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle return make_tuple(c_thread_m, c_thread_n); } - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle() + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = 
CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, @@ -143,32 +193,68 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle "wrong!"); } - // Thread level, register decriptor. Vector-write + // transposed WMMA output C' = B' * A' __host__ __device__ static constexpr auto - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() { constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + + // Thread level, register decriptor. Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( // |MRepeat |MWave |MSubGroup |NRepeat |NWave // |NThreadPerSubGroup |MAccVgprs - make_tuple(Number{}, - I1, - MSubGroup, - Number{}, - I1, - NThreadPerSubGroup, - MAccVgprs)); + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); } - // Provide dimension size + template __host__ __device__ static constexpr auto - GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static 
constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() { constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = make_naive_tensor_descriptor_packed(make_tuple(Number{}, @@ -179,37 +265,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle Number{})); return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); } - __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); - __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); } + // Describe how data allocated in thread copy src buffer // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma - static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); - static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; template __device__ void Run(const ABlockBuffer& a_block_buf, @@ -221,112 +301,259 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... 
- static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0), - a_thread_buf); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; + static_assert(KPack % (A_K1 * A_KRow) == 0, ""); + static_assert(KPack % (B_K1 * B_KRow) == 0, ""); + + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. 
+ // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); - }); + } } protected: - // A[K0, M0, M1, M2, K1] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); - - // B[K0, N0, N1, N2, K1] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - A_K1, - A_K1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - B_K1, - B_K1>; - - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; -}; + template + struct AThreadCopySelector; -// block wise level pipe designed for inline asm + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; + + template <> + struct AThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + false>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using 
type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; +}; +#else template -/* A: K0PerBlock x MPerBlock x K1 + index_t KPack, + bool AEnableLds = true, + bool BEnableLds = true, + bool TransposeC = false> +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 * B: K0PerBlock x NPerBlock x K1 - * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs */ -struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO +struct BlockwiseGemmWMMA { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; static constexpr auto WmmaK = Number<16>{}; using ThisThreadBlock = ThisThreadBlock; @@ -334,18 +561,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. static constexpr index_t WaveSize = 32; - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t KPerBlock = - BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); - static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); - static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = AEnableLds ? 1 : 2; + static constexpr index_t B_KRow = BEnableLds ? 
1 : 2; + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); static constexpr auto wmma_gemm = - WmmaGemm{}; + WmmaGemm{}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); @@ -371,26 +596,39 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); } + // Default, Block buffer in LDS, thread level offset enabled __device__ static auto CalculateAThreadOriginDataIndex() { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - // |KRepeat |MRepeat|MWave |MLane |KPack - return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); + if constexpr(AEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } } __device__ static auto CalculateBThreadOriginDataIndex() { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_n = wave_idx[I1]; - - const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); - // |KRepeat |NRepeat|Nwave |NLane |KPack - return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } } template @@ -421,10 +659,26 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO return make_tuple(c_thread_m, c_thread_n); } - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO() + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, @@ -434,27 +688,42 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO NPerBlock % (NPerWMMA * NRepeat) == 0, "wrong!"); } - // Thread level, register decriptor. 
Vector-write + + // transposed WMMA output C' = B' * A' __host__ __device__ static constexpr auto - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() { constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + + // Thread level, register decriptor. Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( // |MRepeat |MWave |MSubGroup |NRepeat |NWave // |NThreadPerSubGroup |MAccVgprs - make_tuple(Number{}, - I1, - MSubGroup, - Number{}, - I1, - NThreadPerSubGroup, - MAccVgprs)); + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); } template @@ -479,9 +748,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); } - // Provide dimension size + // transposed WMMA output C' = B' * A' __host__ __device__ static constexpr auto - GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() { constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = make_naive_tensor_descriptor_packed(make_tuple(Number{}, @@ -492,37 +761,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO Number{})); return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); } - __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } + constexpr auto 
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); - __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); } + // Describe how data allocated in thread copy src buffer // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma - static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); - static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; template __device__ void Run(const ABlockBuffer& a_block_buf, @@ -534,268 +797,234 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - constexpr auto RepeatDiff = MRepeat - NRepeat; - // Read all Mrepeat, Nrepeat - static_for<0, NRepeat, 1>{}([&](auto iN) { - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); - }); - - static_for<0, MRepeat, 1>{}([&](auto iM) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - }); - - // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat - static_for<0, RepeatDiff, 1>{}([&](auto iCut) { - static_for<0, NRepeat, 1>{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - if constexpr(KPerBlock > WmmaK) - { - // Read Consumed Next inner loop A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - } - }); - - static_for{}([&](auto iWmmaK) { - // Stage 2: Run FIFO fashion loopover in Square - static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { - // Row Repeatation - static_for{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; + // basic intrinsic to determine 
loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); }); - - // Read Consumed Next inner loop A - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple( - Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - - // Col Repeatation - static_for{}([&](auto iM) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. 
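// Inner-most body of this branch (a descriptive note, not part of the patch): for every
// (n0, m0) pair, the KPack-wide slices of B and A are re-read from the block buffers into
// thread-local registers, packed into contiguous vectors, and fed to a single wmma_gemm.Run
// that accumulates into c_thread_buf at the (m0, n0) offset.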
+ // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - // Read Consumed Next inner loop B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); - }); - - // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat - static_for<0, RepeatDiff, 1>{}([&](auto iCut) { - static_for<0, NRepeat, 1>{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - if constexpr(KPerBlock > WmmaK) - { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(iWmmaK + WmmaK) / A_K1>{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - } - }); - }); - - // Stage 2: Run FIFO fashion loopover in Square - static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { - // Row Repeatation - static_for{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); }); - - // Col Repeatation - 
static_for{}([&](auto iM) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - }); + } } protected: - // A[M0, M1, M2, K0 = WmmaK] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); - - // B[N0, N1, N2, K0 = WmmaK] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - A_K1, - A_K1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - B_K1, - B_K1>; - - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; + template + struct AThreadCopySelector; + + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; + + template <> + struct AThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? false : true>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? 
true : false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; }; +#endif } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index d5a64d7aa6fe6c96f384d3f00af642231b6b2c53..d3f6344c27eb56ef265aca802687ae3757881e96 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1,30 +1,16 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { -enum struct LoopScheduler -{ - Default, - Interwave, -}; - -constexpr LoopScheduler make_default_loop_scheduler() -{ -#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING - return LoopScheduler::Interwave; -#else - return LoopScheduler::Default; -#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING -} - template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) @@ -42,7 +28,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) } template + index_t KPack, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = FloatB> struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { static constexpr auto I0 = Number<0>{}; @@ -72,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -308,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -332,26 +322,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf); static_for<0, KPerThread, KPack>{}([&](auto k) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf + a_thread_vec.template AsType()(i) = a_thread_buf [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf + b_thread_vec.template AsType()(i) = b_thread_buf [Number{}]; }); - using mfma_input_type = - typename vector_type::type; + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template 
AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -370,8 +361,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -380,8 +371,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -399,7 +390,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the // default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 template struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + KPack, + ComputeTypeA, + ComputeTypeB> { using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + KPack, + ComputeTypeA, + ComputeTypeB>; #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING using Base::a_block_desc_m0_m1_m2_k; @@ -454,9 +454,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { @@ -487,26 +487,35 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // sync point. if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) { +#ifdef __gfx12__ + asm volatile("\ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("s_barrier" ::); +#endif __builtin_amdgcn_sched_barrier(0); } static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -527,10 +536,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // TODO: insert setprio in more precise manner since we // could have more than >1 MFMA instructions in single call - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) { __builtin_amdgcn_sched_barrier(0); @@ -555,8 +563,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, I1, I1, Number{})); - using AThreadCopy = 
ThreadwiseTensorSliceTransfer_v4, @@ -565,8 +573,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -582,7 +590,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 }; template + LoopScheduler LoopSched, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = FloatB> constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() { if constexpr(LoopSched == LoopScheduler::Default) { return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + KPack, + ComputeTypeA, + ComputeTypeB>{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + KPack, + ComputeTypeA, + ComputeTypeB>{}; } }; @@ -632,26 +649,27 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() * 3. configurable k index starting position and step size after each FMA/XDL instruction */ -template {}.K0PerXdlops, - index_t BMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops> +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename ATileDesc, + typename BTileDesc, + typename AMmaTileDesc, + typename BMmaTileDesc, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerXDL, + index_t NPerXDL, + index_t MRepeat, + index_t NRepeat, + index_t KPack, + bool TransposeC = false, + index_t AMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> struct BlockwiseGemmXdlops_v2 { static constexpr auto I0 = Number<0>{}; @@ -668,7 +686,8 @@ struct BlockwiseGemmXdlops_v2 static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -781,11 +800,6 @@ struct BlockwiseGemmXdlops_v2 "wrong!"); } - __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other) - : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin) - { - } - // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() { @@ -954,10 +968,9 @@ struct BlockwiseGemmXdlops_v2 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index 8ae1ba3f34c10d8359b470ed1172f1b70a7fa8b5..287c6701c3391de0d5b34d7a4da55bba76ae7d6a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -281,10 +281,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a737c9195b436b592cffc89dc272706d80cb8fb4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +/** + * Transfer that uses direct load instructions to copy data from global to LDS memory. + * + * Traditional loads first copy data from global to registers, and then from registers to LDS. + * Direct loads do not need an intermediate step, data is copied directly from global to LDS, + * without the use of additional registers. + * + * However, the instruction has limitations: + * - each thread must copy exactly a single DWORD - 4 bytes; + * - threads within a single wavefront must write consecutive DWORDS into LDS, + * (data in global do not need to be contiguous, each thread might have its own offset). + * + * To make sure that all the transfers finished, the `waitcnt` instruction must be used with + * `vmcnt` instead of `lgkmcnt`. + * + * Limitations of the transfer class: + * - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight; + * - `DstVectorDim` must be the last dimension; + * - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1; + * - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B + * (for examlpe if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16, + * `ScalarPerVector` must be 2); + * - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be + * the same dimension; + * - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64, + * they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way + * to guarantee that. 
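 * A quick sizing example that follows from the DWORD rule above (illustrative, not an
 * additional API guarantee): ScalarPerVector * sizeof(DstData) must equal 4 bytes, so
 * fp32 -> ScalarPerVector = 1, fp16/bf16 -> 2, int8 -> 4; with a 64-thread wavefront a single
 * iteration then writes 64 consecutive DWORDs (256 bytes) into LDS.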
+ */ +template +struct ThreadGroupTensorSliceTransfer_DirectLoad +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + static constexpr auto block_slice_lengths = BlockSliceLengths{}; + static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; + + static constexpr auto thread_single_load_size = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + // After a load, each thread moves by `thread_steps` instead of loading the next elements. + // It makes the whole wavefront load contiguous memory, what is required for direct loads. + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; + + static __device__ constexpr bool AreThreadClusterLengthsValid() + { + // Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to + // LDS by the threads from a single wavefront. + // Examples (assuming 64 threads in a wavefront, 128 in a thread block): + // 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp32 -> ScalarPerVector = 1 + // INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of + // [0, 4, 0]. + // VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration, + // threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs). + // 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp16 -> ScalarPerVector = 2 + // NOTE: ThreadClusterLengths must take into account that each thread writes two + // elements (single DWORD) along the contiguous dimension. + // INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write + // 8 * 2 elements of K1PerBlock and there are only 8; + // ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32 + // writes [1, 0, 0] instead of [0, 8, 0]. + // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the + // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive + // elements = 64 consecutive DWORDs. 
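// The check below (annotated here for clarity) walks the dimensions from innermost to
// outermost: it multiplies the thread-cluster lengths together while consecutive threads are
// still adjacent (per-thread slice length of 1 in the dimensions passed so far) and stops
// extending the product once a dimension in which each thread covers more than one element has
// been crossed. The resulting run of adjacent threads must cover a whole number of wavefronts,
// so that every wavefront writes contiguous DWORDs. It also verifies that the thread cluster,
// scaled by the per-thread load size, divides the block slice, i.e. every per-thread slice
// length is at least 1.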
+ int num_contiguous_dwords = 1; + bool is_contiguous = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(is_contiguous) + { + num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1]; + } + if(thread_slice_lengths[nDim - i - 1] > 1) + { + is_contiguous = false; + } + }); + constexpr index_t wavefront_size = get_warp_size(); + const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0; + + bool thread_slice_lengths_correct = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(thread_slice_lengths[i] <= 0) + { + thread_slice_lengths_correct = false; + } + }); + + return wave_contiguous && thread_slice_lengths_correct; + } + + __device__ constexpr ThreadGroupTensorSliceTransfer_DirectLoad( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + + { + static_assert(ck::is_same_v, + "Direct load transfer does not support datatypes conversion. Source and " + "destination data types must be the same."); + + static_assert( + DstVectorDim == nDim - 1, + "Direct load transfer requires the destination vector dimension to be the last one."); + + static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim, + "When loading more than one element per thread at once, the contiguous " + "dimension must be the same between source and destination."); + + constexpr auto dword_bytes = 4; + constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData); + static_assert(bytes_per_thread_load == dword_bytes, + "Direct load transfer requires each thread to load exactly a single " + "DWORD of data."); + + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size(), + "Inconsistent number of dimensions across lengths and descriptors."); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "The number of threads cannot be less than the number of elements in " + "thread cluster lengths."); + + static_assert( + AreThreadClusterLengthsValid(), + "Thread cluster lengths are incorrect. They must be set in a way that allows a single " + "wavefront to write contiguous DWORDs into LDS memory. 
"); + + const auto thread_cluster_idx = + thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; + + SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + src_slice_origin_ = src_slice_origin_idx; + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst_slice_origin_ = dst_slice_origin_idx; + } + + __device__ void ResetDstSliceWindow(const DstDesc& dst_desc) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global, + "Source data must come from a global memory buffer."); + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "Destination data must be stored in an LDS memory buffer."); + + static_assert( + ck::is_same_v, remove_cvref_t>, + "SrcBuffer and SrcData data types must be consistent."); + static_assert( + ck::is_same_v, remove_cvref_t>, + "DstBuffer and DstData data types must be consistent."); + + constexpr auto dst_access_lengths = thread_slice_lengths; + + const auto dst_forward_steps = generate_steps(dst_desc, 1); + const auto dst_backward_steps = generate_steps(dst_desc, -1); + const auto src_forward_steps = generate_steps(src_desc, 1); + const auto src_backward_steps = generate_steps(src_desc, -1); + + // Loop over the destination block and copy data. + static_ford{}([&](auto ordered_dst_access_idx) { + const auto src_offset = src_coord_.GetOffset(); + const auto dst_offset = dst_coord_.GetOffset(); + + // Check if src data is not in the logic padding area. + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + src_buf.template DirectCopyToLds, ScalarPerVector>( + dst_buf, src_offset, dst_offset, is_src_valid); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // Decide whether to move forward or backward. 
+ constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]); + move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]); + } + else + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]); + move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]); + } + } + }); + }); + + // Reset the destination slice since the entire buffer has been already filled. + ResetDstSliceWindow(dst_desc); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + src_slice_origin_ = src_slice_origin_ + step; + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_); + } + + template + __device__ auto generate_steps(const DescType& desc, int sign) + { + return generate_tuple( + [&](auto i) { + Index step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0; + }); + + return make_tensor_coordinate_step(desc, step_idx); + }, + Number{}); + } + + private: + static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{}); + + SrcCoord src_coord_; + DstCoord dst_coord_; + Index src_slice_origin_; + Index dst_slice_origin_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ab826bb041644220365889ae1bc6b2a163599f80 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer with dequantization + * + * RunRead would load low-precision data and scale data. + * RunWrite would process dequantization process. + * Assume Scale is identical along K-dimension + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. 
ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r1_dequant +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto scale_thread_slice_lengths = + BlockScaleSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, + const ScaleDesc& scale_desc, + const Index& scale_block_slice_origin, + const ScaleElementwiseOperation& scale_element_op, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const DstElementwiseOperation& dst_element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + src_element_op, + scale_desc, + make_zero_multi_index(), + scale_element_op, + dst_desc, + make_zero_multi_index(), + dst_element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{} && + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetScaleSliceOrigin( + scale_desc, scale_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); + } + } + + // With the assumption, scale scratch is always one + template + __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunScaleRead(scale_desc, scale_buf); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + } + + // We don't prefer use this API directly + /* + 
template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id) + { + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + */ + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + // With the assumption, scale buffer don't need move slice window method + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1_dequant; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aa1f7c573597d2174c46244b80b65c5dd478753f --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r2 +{ + static constexpr index_t nDim = + remove_reference_t>::GetNumOfDimension(); + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + + { + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == SrcDimAccessOrder::Size(), + "wrong! 
nDim not consistent"); + + static_for<0, nSrc, 1>{}([&](auto src_i) { + static_assert(nDim == + remove_cvref_t>::GetNumOfDimension(), + "wrong! nDim not consistent"); + }); + + static_for<0, nDst, 1>{}([&](auto dst_i) { + static_assert(nDim == + remove_cvref_t>::GetNumOfDimension(), + "wrong! nDim not consistent"); + }); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers& dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffer& src_bufs, + const DstDescs& dst_descs, + DstBuffer& dst_bufs, + Number thread_scratch_id) + { + RunRead(src_descs, src_bufs, thread_scratch_id); + RunWrite(dst_descs, dst_bufs, thread_scratch_id); + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..1c4de5ed3153f2cc8ca7f4a8ccf57741085fba40 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags, + index_t NumThreadScratch = 1> +struct ThreadGroupTensorSliceTransfer_v7r2 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + using is_tuple = decltype(std::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + else + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp new file 
mode 100644 index 0000000000000000000000000000000000000000..46d0c6ac2eb4eeb0dfad627834b6d7fbaa991ee7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags, + index_t NumThreadScratch = 1> +struct ThreadGroupTensorSliceTransfer_v7r3 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + using is_tuple = decltype(std::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + else + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp b/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..dc08a2c88b2abbd9b26412121ae223af6c3e1c01 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace conv_tensor_rearrange_op { + +struct BaseConvTensorRearrangeOp +{ +}; + +struct ImageToColumn : public BaseConvTensorRearrangeOp +{ + static constexpr const char* name = "Image to Column"; +}; + +struct ColumnToImage : public BaseConvTensorRearrangeOp +{ + static constexpr const char* name = "Column to Image"; +}; + +template ::value, + bool>::type = false> +std::ostream& operator<<(std::ostream& os, const BaseConvTensorRearrangeOp&) +{ + os << Op::name; + return os; +} + +} // namespace conv_tensor_rearrange_op +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index adfa1689c66509c8c194985365356740d4c90473..0eef827a5b5018efeee94d11f1b768c739d85648 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,6 +15,7 @@ enum struct ConvolutionForwardSpecialization Filter1x1Pad0, Filter1x1Stride1Pad0, OddC, + Filter3x3, }; inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) @@ -25,6 +26,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; case ConvolutionForwardSpecialization::OddC: return "OddC"; + case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3"; default: return "Unrecognized specialization!"; } } diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index f59021a8a079f9e9f8adcfa7f8c406c8515712c8..08371ed885365f6135e80e1caac944826310b2c8 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -62,7 +62,9 @@ struct BaseOperator }; virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } - virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const + virtual void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const { assert(p_arg); p_arg->p_workspace_ = p_workspace; diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp index f18dc3290600e63edb869d93dcb1a903206e32ab..58c0288e8f04c72c463d1fdd8fc59cda424d3a2f 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
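The hunk above adds a Filter3x3 entry to ConvolutionForwardSpecialization and a matching case in getConvForwardSpecializationString. A minimal host-side round-trip sketch (not part of the diff), assuming the enum and helper live in ck::tensor_operation::device as the include path suggests:

#include <iostream>
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"

int main()
{
    using ck::tensor_operation::device::ConvolutionForwardSpecialization;
    using ck::tensor_operation::device::getConvForwardSpecializationString;

    // The new enumerator maps to the string "Filter3x3" through the switch above.
    const auto spec = ConvolutionForwardSpecialization::Filter3x3;
    std::cout << getConvForwardSpecializationString(spec) << std::endl;
    return 0;
}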
#pragma once @@ -53,6 +53,47 @@ struct DeviceBatchedGemmMultiD : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; +template +struct DeviceBatchedGemmV2MultiD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor"); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index 8484212118c796f68d1bf066eab7b8ddbb231797..44dedeeef95c94487a7efd732033e1a4bf8c5dab 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "device_base.hpp" @@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator CElementwiseOperation c_element_op, ck::index_t KBatch = 1) = 0; - virtual std::unique_ptr MakeInvokerPointer() = 0; + virtual std::unique_ptr MakeInvokerPointer() = 0; virtual std::size_t GetWorkspaceSize(index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, - index_t StrideC) = 0; + index_t StrideC) const = 0; }; template + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A0[M0, M1, ... K0, K1, ...], ... +// input : B0[N0, N1, ... K0, K1, ...], ... +// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ... +// output : E[M0, M1, ... N0, N1, ...] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceContractionMultipleABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp index 118ade8978643f10dcfafbefb552fe485d6697e4..dffb51d0ba7517446df5c9a3901a10be33b99db8 100644 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -33,7 +33,8 @@ template + typename CDEElementwiseOperation, + typename ComputeDataType = ADataType> struct DeviceContractionMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b28204fd841cd1430e30a255dacf5b9a4583a929 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \brief Convolution Tensor Rearrange. + * + * This Device operator supports converting an image to + * the GEMM representation (Image to Column) and + * converting a GEMM form to the image (Column to Image). + * Supported layouts: + * [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C] + * [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C] + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ImageLayout Input Layout. + * \tparam InputDataType Input Data Type. + * \tparam OutputDataType Output Data Type. + * \tparam ConvTensorRearrangeOp Operation type: ImageToColumn, ColumnToImage. + */ +template +struct DeviceConvTensorRearrange : public BaseOperator +{ + + /** + * \brief Make argument pointer for image to column. + * + * \param p_in A pointer to the device memory of the input image. + * \param p_out A pointer to the device memory of the output. + * \param G Convolution number of groups. + * \param N Convolution batch size. + * \param C Convolution number of channels. + * \param input_spatial_lengths Input spatial lengths. + * \param filter_spatial_lengths Filter spatial lengths. + * \param output_spatial_lengths Output spatial lengths. 
+ * \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W]. + * \param gemm_g_m_k_strides Gemm form strides. + * \param conv_filter_strides Convolution filter strides. + * \param conv_filter_dilations Convolution filter dilations. + * \param input_left_pads Convolution left pads. + * \param input_right_pads Convolution right pads. + * \return Pointer to the argument. + */ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_out, + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac6ff0a960261f633c674b1608f2b2a36fdf17b4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \note This structure is deprecated (left for backwards compatibility). Please use + * DeviceElementwise from device_elementwise.hpp. + */ +template +struct DeviceElementwise : public BaseOperator +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; // namespace device + +template +using DeviceElementwisePtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp new file mode 100644 index 0000000000000000000000000000000000000000..acb18efabfe2198955d692d0292ec77d3bd4f58d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Dequantization of input tensor could not be decoupled from gridwisegemm pipeline +// As input tensor thread buffer declared inside blockwise-gemm pipeline. 
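To make the layout comment on DeviceConvTensorRearrange above concrete, the following standalone sketch computes the GEMM-form shape that ImageToColumn produces for a 2D problem; the sizes are illustrative and nothing here is CK API:

#include <cstdint>
#include <iostream>

int main()
{
    // Illustrative 2D problem (values are not taken from the diff).
    const std::int64_t G = 1, N = 2, C = 8; // groups, batch, input channels
    const std::int64_t Ho = 28, Wo = 28;    // convolution output spatial sizes
    const std::int64_t Y = 3, X = 3;        // filter spatial sizes

    // [G, N, Hi, Wi, C] -> [G, N * Ho * Wo, Y * X * C], per the doc comment above.
    const std::int64_t gemm_m = N * Ho * Wo;
    const std::int64_t gemm_k = Y * X * C;

    std::cout << "GEMM form: [" << G << ", " << gemm_m << ", " << gemm_k << "]\n";
    return 0;
}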
+ +template +struct DeviceGemm_dequantB : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbb9fadc6d6f3f03673f4c5a7f13229f29b5d849 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A0[M, K], B0[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 5abe7ea53f4f5469af6f7bdd4f65cd4ce0219be0..0295e5b0a9964c2568d0334a76d329750a7a8ae9 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -56,6 +56,49 @@ struct DeviceGemmMultipleD : public BaseOperator #endif }; +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleDSplitK : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7171715250bb404520bae9c289e8fe07c1366207 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleD_ABScale : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp index 6407aa7e09b81a4dc09dcb1bdca5701aa86d9e14..dfc27e726eaafaddf13ac62c530567eb32ab0ddc 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp @@ -20,7 +20,8 @@ template + typename CElementwiseOperation, + typename ComputeType = CDataType> struct DeviceGemmSplitK : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_a, @@ -48,7 +49,8 @@ template + typename CElementwiseOperation, + typename ComputeType = CDataType> using DeviceGemmSplitKPtr = std::unique_ptr>; + CElementwiseOperation, + ComputeType>>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1a4d684f1457fb89a81e4b15e723b7969be1cdde --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, 
Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Streamk_V2 : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t Streamk_sel, + ck::index_t Grid_size, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b2db35b159eb709bb3a12dc6c09d7426aa1a323a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmV2 : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +struct DeviceGemmV2R1 : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array DsStrides, + ck::index_t StrideC, + ck::index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp index 7e4bca2bd66a55c70b6a9365e2f1debcc22bb119..2abf1d5a10c7e8645ca51d842390d50d5fbce5dd 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -29,7 +29,9 @@ template + typename CDEElementwiseOperation, + typename AComputeType = ADataType, + typename BComputeType = AComputeType> struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index ab9e6adb41963df1f9cd56ca2a7ed3056275defe..7296e4faaa30c56cab55ac5140b94a20c45782b1 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ 
b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -20,7 +20,9 @@ template + typename OutElementwiseOperation, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeight : public BaseOperator { virtual std::unique_ptr diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6febf702f9e6b44c3fe7f4ad35f15e5507dea159 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedConvBwdWeightMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsLayout::Size(); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..184efbbd68ecea5b3c2b36f52764690e3ad316da --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +using is_tuple = decltype(std::declval().IsTuple()); + +/** + * \brief Grouped Convolution Forward + * + * \details + * input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]... + * input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]... + * input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... + * output : output image E[G, N, K, Ho, Wo] + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ALayout Input layout (also for a1, a2...). + * \tparam BLayout Weight layout (also for b1, b2...). + * \tparam DsLayout Ds layouts. 
+ * \tparam ELayout Output layout. + * \tparam ADataType Input data type. Pass tuple if there is multiple A. + * \tparam BDataType Weight data type. Pass tuple if there is multiple B. + * \tparam DsDataType D data types. + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation CDE elementwise operation. + * \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed). + * \tparam BComputeType Compute data type for B tensor (default: AComputeType). + */ +template ::value, + Number<0>, + ADataType>()), // AComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeType = AComputeType> +struct DeviceGroupedConvFwdMultipleABD : public BaseOperator +{ + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); + + // If DataType is tuple, user has to pass std::array with pointers. + using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + + /** + * \brief Make argument pointer for grouped conv fwd. + * + * \param p_a A pointer to the input (std::array with + pointers for multiple A). + * \param p_b A pointer to the weight (std::array with + pointers for multiple B). + * \param p_ds A pointers to the Ds. + * \param p_e A pointers to the output. + * \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d). + * \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d). + * \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d). + * \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d). + * \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d). + * \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d). + * \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d). + * \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d). + * \param conv_filter_strides Convolution filter strides. + * \param conv_filter_dilations Convolution filter dilations. + * \param input_left_pads Input left paddings. + * \param input_right_pads Input right paddings. + * \param a_element_op A elementwise operation object. + * \param b_element_op B elementwise operation object. + * \param cde_element_op CDE elementwise operation object. + * \return Pointer to the argument. 
+ */ + virtual std::unique_ptr MakeArgumentPointer( + APointers p_a, + BPointers p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) = 0; + + virtual std::unique_ptr + MakeArgumentPointer(APointers p_a, + BPointers p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp index 1cc30fd9e6ad10336a49a261b3ee7c47f9053435..daf397189a4ed91d190a463ef3ec4212837e5abf 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp @@ -3,21 +3,33 @@ #pragma once -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" namespace ck { namespace tensor_operation { namespace device { -// Convolution Forward: -// input : input image A[G, N, C, Hi, Wi], -// input : weight B[G, K, C, Y, X], -// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... -// output : output image E[G, N, K, Ho, Wo] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) +/** + * \brief Grouped Convolution Forward + * + * \note This structure is deprecated (left for backwards compatibility). Please use + * DeviceGroupedConvFwdMultipleABD. + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ALayout Input layout (also for a1, a2...). + * \tparam BLayout Weight layout (also for b1, b2...). + * \tparam DsLayout Ds layouts. + * \tparam ELayout Output layout. + * \tparam ADataType Input data type. Pass tuple if there is multiple A. + * \tparam BDataType Weight data type. Pass tuple if there is multiple B. + * \tparam DsDataType D data types. + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation CDE elementwise operation. 
+ * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed). + */ template -struct DeviceGroupedConvFwdMultipleD : public BaseOperator -{ - static constexpr index_t NumDTensor = DsDataType::Size(); - - static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); - - virtual std::unique_ptr MakeArgumentPointer( - const void* p_a, // input image - const void* p_b, // weight - const std::array& p_ds, - void* p_e, // output image - const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; + typename CDEElementwiseOperation, + typename ComputeType = + decltype(UnpackDataType::value, + Number<0>, + ADataType>())> // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed +using DeviceGroupedConvFwdMultipleD = DeviceGroupedConvFwdMultipleABD; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fcb2ba6a4d7be839c11cf9794bb7beccf7845d3c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GroupedGemmKernelArgument +{ + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + + index_t M; + index_t N; + index_t K; + + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; +}; + +template +struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm +{ + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; + virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..59483cb89fa6b5beff16fefdeb88beb304669d87 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
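GroupedGemmKernelArgument above is the flat per-group record that the fixed-NK grouped GEMM consumes through SetDeviceKernelArgs. A hedged sketch of filling one record for a group with no D tensors; treating the elided template parameter as the D-tensor count is an assumption, as is the ck::index_t element type of the stride array:

#include <vector>
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"

// Assumes the struct is parameterized on the number of D tensors.
using KernelArg = ck::tensor_operation::device::GroupedGemmKernelArgument<0>;

std::vector<KernelArg> make_kernel_args(const void* p_a, const void* p_b, void* p_e)
{
    // One group; shapes and row-major strides are illustrative.
    KernelArg arg{p_a,
                  p_b,
                  {}, // p_ds_grid: no D tensors
                  p_e,
                  /*M*/ 256,
                  /*N*/ 128,
                  /*K*/ 64,
                  /*StrideA*/ 64,
                  /*StrideB*/ 128,
                  {}, // StrideDs
                  /*StrideE*/ 128};
    return {arg};
}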
+ +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct GemmMultiABDDesc +{ + ck::index_t M_, N_, K_; + + std::vector stride_As_; + std::vector stride_Bs_; + std::vector stride_Ds_; + + ck::index_t stride_C_; +}; + +/* + * \brief Grouped Gemm Multi ABD + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam AsLayout A layouts (tuple). + * \tparam BsLayout B layouts (tuple). + * \tparam DsLayout Ds layouts (tuple). + * \tparam ELayout Output layout. + * \tparam AsDataType A data types (tuple). + * \tparam BsDataType B data types (tuple). + * \tparam DsDataType D data types (tuple). + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation C elementwise operation. + */ +template +struct DeviceGroupedGemmMultiABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(AsLayout::Size() == AsDataType::Size(), "wrong! inconsistent NumATensor"); + static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor"); + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor"); + + /* + * \brief Make argument pointer for grouped gemm multi abd. + * + * \param p_as A pointers to the A. + * \param p_bs A pointers to the B. + * \param p_ds A pointers to the Ds. + * \param p_e A pointers to the E. + * \param gemm_desc Gemm descriptors for each group. + * \param a_element_op A elementwise operation object. + * \param b_element_op B elementwise operation object. + * \param cde_element_op CDE elementwise operation object. + * \return Pointer to the argument. + */ + virtual std::unique_ptr + MakeArgumentPointer(std::vector>& p_as, + std::vector>& p_bs, + std::vector>& p_ds, + std::vector& p_e, + std::vector& gemm_desc, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual void SetElementwiseOps(BaseArgument* p_arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..05c3e796a8b74315585506cf66adab4408e8bade --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
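GemmMultiABDDesc above carries the per-group problem shape handed to DeviceGroupedGemmMultiABD::MakeArgumentPointer. A hedged host-side sketch of building one descriptor per group; the stride values are illustrative row-major picks and the vector element type is assumed to be ck::index_t:

#include <vector>
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"

std::vector<ck::tensor_operation::device::GemmMultiABDDesc> make_gemm_descs(int group_count)
{
    std::vector<ck::tensor_operation::device::GemmMultiABDDesc> descs;
    for(int g = 0; g < group_count; ++g)
    {
        ck::tensor_operation::device::GemmMultiABDDesc desc;
        desc.M_         = 256;
        desc.N_         = 128;
        desc.K_         = 64;
        desc.stride_As_ = {64};  // a single row-major A[M, K]
        desc.stride_Bs_ = {128}; // a single B tensor
        desc.stride_Ds_ = {};    // no D tensors in this sketch
        desc.stride_C_  = 128;   // E[M, N], row-major
        descs.push_back(desc);
    }
    return descs;
}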
+ +#pragma once + +#include +#include + +#include "device_grouped_gemm_multi_abd.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GroupedGemmMultiABDKernelArgument +{ + std::array p_as_grid; + std::array p_bs_grid; + std::array p_ds_grid; + void* p_e_grid; + + index_t M; + index_t N; + index_t K; + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + index_t StrideE; +}; + +/* + * \brief Grouped Gemm Multi ABD Fixed NK + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam AsLayout A layouts (tuple). + * \tparam BsLayout B layouts (tuple). + * \tparam DsLayout Ds layouts (tuple). + * \tparam ELayout Output layout. + * \tparam AsDataType A data types (tuple). + * \tparam BsDataType B data types (tuple). + * \tparam DsDataType D data types (tuple). + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation C elementwise operation. + */ +template +struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD +{ + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; + virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d91eac07302fc31a662a9c55c5a5a6d9894bd7d0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Structure representing single GEMM problem arguments. +/// +/// The pointer to the vector of those structures is passed to the GroupedGEMM entry +/// point kernel. +/// +/// @tparam NumDTensor The number of D input tensors. 
+/// +template +struct GroupedGemmMultipleDKernelArguments +{ + __host__ __device__ + GroupedGemmMultipleDKernelArguments(const void* p_a_grid_, + const void* p_b_grid_, + std::array p_ds_grid_, + void* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_) + : p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{p_ds_grid_}, + p_e_grid{p_e_grid_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideE{StrideE_} + { + } + + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; + + void Print() const + { + std::stringstream str; + for(auto sd : StrideDs) + str << sd << ","; + + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SE:" << StrideE << ", " + << "SDs: {" << str.str() << "}" + << "}" << std::endl; + } +}; + +template +struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm +{ + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// + virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel + /// arguments. + /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Gets the device kernel argument size. + /// + /// @param[in] p_arg The pointer to the Device op Argument. + /// + /// @return The device kernel argument size. + /// + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c1030f31ccc187192aa5e552fab79e41e785861d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Structure representing single GEMM problem arguments. +/// +/// The pointer to the vector of those structures is passed to the GroupedGEMM entry +/// point kernel. +/// +/// @tparam NumDTensor The number of D input tensors. 
+/// +template +struct GroupedGemmTileLoopKernelArguments +{ + __host__ __device__ + GroupedGemmTileLoopKernelArguments(const void* p_a_grid_, + const void* p_b_grid_, + std::array p_ds_grid_, + void* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_) + : p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{p_ds_grid_}, + p_e_grid{p_e_grid_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideE{StrideE_} + { + } + + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; + + void Print() const + { + std::stringstream str; + for(auto sd : StrideDs) + str << sd << ","; + + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SE:" << StrideE << ", " + << "SDs: {" << str.str() << "}" + << "}" << std::endl; + } +}; + +template +struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm +{ + //---------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel + /// arguments. + /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Gets the device kernel argument size. + /// + /// @param[in] p_arg The pointer to the Device op Argument. + /// + /// @return The device kernel argument size. + /// + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp b/include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp deleted file mode 100644 index bf81ed9f5b1c2e6b46208ffede15370a3ae78088..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp +++ /dev/null @@ -1,32 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
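The two virtuals on DeviceGroupedGemmTileLoop above imply a host-side handoff: query the argument size, stage the per-group kernel arguments in device memory, then give the device pointer back to the op. A hedged sketch of that flow, assuming a concrete implementation behind `op` and omitting HIP error checking:

#include <hip/hip_runtime.h>
#include <cstddef>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"

// `Op` is assumed to be a concrete DeviceGroupedGemmTileLoop implementation and
// `p_arg` an argument created by its MakeArgumentPointer.
template <typename Op, typename KernelArg>
void stage_grouped_gemm_args(const Op& op,
                             ck::tensor_operation::device::BaseArgument* p_arg,
                             const std::vector<KernelArg>& host_args)
{
    const std::size_t arg_bytes = op.GetDeviceKernelArgSize(p_arg);

    void* p_dev_args = nullptr;
    hipMalloc(&p_dev_args, arg_bytes);
    hipMemcpy(p_dev_args, host_args.data(), arg_bytes, hipMemcpyHostToDevice);

    // The op only stores the pointer; the caller keeps ownership of the buffer.
    op.SetDeviceKernelArgs(p_arg, p_dev_args);
}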
- -#pragma once - -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// For pooling which used indexable operation, such as MaxPool, MinPool...etc -template -struct DeviceIndexPoolBwd : public BaseOperator -{ - virtual std::unique_ptr - MakeArgumentPointer(const void* p_dout, - const void* p_indices, - void* p_din, - index_t dout_length, - index_t din_length, - std::vector window_lengths, - std::vector window_strides) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp b/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a4a9cac1e8ea78b6683b1c4ce9629a8c87b9802 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// For pooling which used indexable operation, such as MaxPool, MinPool...etc +template +struct DeviceMaxPoolBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp deleted file mode 100644 index 1f178f9fcb65ffdd7ab09146d94131c1d5c993f2..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp +++ /dev/null @@ -1,62 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
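DeviceMaxPoolBwd above replaces the removed DeviceIndexPoolBwd and appends window_dilations to the argument list. A hedged sketch of assembling the argument for a 2x2, stride-2, dilation-1 window; `op` stands for a concrete implementation and the pointers and lengths are placeholders:

#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp"

template <typename MaxPoolBwdOp>
auto make_max_pool_bwd_argument(MaxPoolBwdOp& op,
                                const void* p_dout,
                                const void* p_indices,
                                void* p_din,
                                ck::index_t dout_length,
                                ck::index_t din_length)
{
    return op.MakeArgumentPointer(p_dout,
                                  p_indices,
                                  p_din,
                                  dout_length,
                                  din_length,
                                  {2, 2},  // window_lengths
                                  {2, 2},  // window_strides
                                  {1, 1}); // window_dilations (new parameter)
}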
- -#pragma once - -#include -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -template -struct DeviceNormalization : public BaseOperator -{ - virtual std::unique_ptr - MakeArgumentPointer(const std::vector lengths, - const std::vector xStrides, - const std::vector gammaStrides, - const std::vector betaStrides, - const std::vector yStrides, - const std::vector reduceDims, - double epsilon, - const void* p_x, - const void* p_gamma, - const void* p_beta, - void* p_y, - void* p_savedMean, - void* p_savedInvVar, - YElementwiseOperation y_elementwise_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -template -using DeviceNormalizationPtr = std::unique_ptr>; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp new file mode 100644 index 0000000000000000000000000000000000000000..327acfeb53e9f468bfdef4713a8e3b748f876eca --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationBwdDataPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a8804cbf0e86b48b9477978a4691f6159af3f09 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
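For DeviceNormalizationBwdData above, the stride vectors are how rank differences are expressed: gamma is per-feature while mean and invStd are per-row, and a common convention in CK examples is to broadcast them with zero strides. The values below are an illustrative assumption for a row-major [N, D] layernorm-style case, not something stated in this hunk, and the helper struct is purely hypothetical:

#include <vector>

struct NormalizationBwdDataShapes
{
    std::vector<long> lengths, dyStrides, xStrides, gammaStrides,
                      meanStrides, invStdStrides, dxStrides, reduceDims;
};

NormalizationBwdDataShapes make_layernorm_bwd_data_shapes(long N, long D)
{
    // x, dy, dx: [N, D] row-major; gamma: [D]; mean/invStd: [N]; reduce over dim 1.
    return {{N, D}, // lengths
            {D, 1}, // dyStrides
            {D, 1}, // xStrides
            {0, 1}, // gammaStrides (assumed zero-stride broadcast over N)
            {1, 0}, // meanStrides (assumed zero-stride broadcast over D)
            {1, 0}, // invStdStrides
            {D, 1}, // dxStrides
            {1}};   // reduceDims
}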
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationBwdGammaBeta : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_mean, + const void* p_invStd, + void* p_dgamma, + void* p_dbeta) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationBwdGammaBetaPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d252ad1d98f57b962e262efef74b5ae20e8a071c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + double epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_savedMean, + void* p_savedInvVar, + YElementwiseOperation y_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationFwdPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..67286e5541e8103747f371ba2bc9ae17c6ed1e98 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceMultiD : public BaseOperator +{ + static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 
1 : Rank - NumReduceDim; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array, NumDTensor> DsLengths, + const std::array, NumDTensor> DsStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + const void* in_dev, + const std::array ds_dev, + void* out_dev, + const InElementwiseOperation in_elementwise_op, + const OutElementwiseOperation out_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceReduceMultiDPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp b/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp deleted file mode 100644 index 16ca582b89c70a2819373fea6962daf34ff2b8a2..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp +++ /dev/null @@ -1,18 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -namespace ck { -namespace tensor_operation { -namespace device { - -enum struct GemmDlAlgorithm -{ - Default, // Uses DOT vector instructions - Dpp8, // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/helper.hpp b/include/ck/tensor_operation/gpu/device/helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c0e5ce9dd3ca778af571d6f4d1a3d455089d65a7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/helper.hpp @@ -0,0 +1,478 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
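// Reading guide for this helper (a sketch of the intended call chain, assuming template-parameter
// strings generated by the ck::host code-generation library): get_launch_params,
// get_launch_params_1d and get_launch_params_3d read the generated parameters from a
// ck::host::Solution, map them onto GemmSpecialization / ConvolutionForwardSpecialization with
// gemm_type() and conv_type(), build the convolution-to-GEMM transform and the padded output
// descriptor with the transform_conv* functions and pad(), and finally query block_2_etile() for
// the dispatch grid size.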
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include +#include + +// functions to return the corresponding structs based on generated template parameters + +using layouts = std::variant; +// return the layout type: currently this is the only type supported in MIOpen +auto layout_type(std::string type) +{ + if(type == "ck::tensor_layout::convolution::NHWGK") + { + return ck::tensor_layout::convolution::NHWGK{}; + } + throw std::runtime_error("Incorrect layout"); +} +// return the right gemm spec based on the generated template parameters +ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type) +{ + if(type == "ck::tensor_operation::device::GemmSpecialization::Default") + { + return ck::tensor_operation::device::GemmSpecialization::Default; + } + if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding") + { + return ck::tensor_operation::device::GemmSpecialization::MNKPadding; + } + throw std::runtime_error("Incorrect gemm spec: " + type); +} + +// return the type of convolution +ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type) +{ + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + } + if(type == + "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + } + throw std::runtime_error("Incorrect conv spec: " + type); +} + +// Function to call on MatrixPadder via a wrapper struct +// NOTE: CK only uses MNKPadding for forward convolution +template +auto pad(ck::index_t mpb, + ck::index_t npb, + ck::index_t kpb, + ck::tensor_operation::device::GemmSpecialization gemm, + CDesc_MRaw_NRaw conv) +{ + if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding) + { + ck::tensor_operation::device::MatrixPadder< + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + ck::index_t, + ck::index_t, + ck::index_t> + a; + a.MPerTile_ = mpb; + a.NPerTile_ = npb; + a.KPerTile_ = kpb; + auto tmp = grid_desc(a, conv); + return tmp; + } + throw std::runtime_error("Incorrect template parameters, check gemm spec"); +} + +// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num +// dims +// FIXME: add a way to properly pass in the layout +auto transform_conv(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + ck::Array dummy_dims; + ck::Array dummy_spatial_dims; + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> 
+ conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_3d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + ck::Array dummy_dims; + ck::Array dummy_spatial_dims; + + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + 
dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_1d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + ck::Array dummy_dims; + ck::Array dummy_spatial_dims; + + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + throw std::runtime_error("Incorrect dims or conv spec"); +} + +template +auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder) +{ + if(m_per_block == 32 && n_per_block == 64) + { + auto b2e = ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N>(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 32 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> 
b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 256) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 256 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + throw std::runtime_error("Incorrect template parameters"); +} + +// wrapper functions by dims to get grid size - uses above 3 functions +// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions +auto get_launch_params_1d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, 
n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params_3d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..180e32c8b6b8ee41577b3f9700614990a102b2ee --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,1021 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). 
+ * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + + const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t a_batch_offset = 
__builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + + GridwiseGemm::template Run( + p_as_grid + a_batch_offset, + p_bs_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ + + device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + AsPointer, // tuples if multi AB, pointers if no + BsPointer, + DsPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfBatch, + HasMainKBlockLoop, + isMultiA, + isMultiB>(p_as_grid, + p_bs_grid, + p_ds_grid, + *p_e_grid, + a_element_op, + b_element_op, + cde_element_op, + batch_count, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map, + compute_ptr_offset_of_batch); +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. 
+// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. 
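// For reference, the three descriptor helpers above realize the usual implicit-GEMM view of
// forward convolution (per group): GemmM = N * (output spatial size), GemmN = K and
// GemmK = C * (filter spatial size), so A is the im2col view of the input, B the weight tensor
// and E the output. matrix_padder then right-pads GemmM / GemmN / GemmK up to multiples of
// MPerBlock / NPerBlock / KPerBlock according to GemmSpec, which is what lets arbitrary problem
// sizes reuse one fixed tile configuration. The Ds descriptors built just below reuse the same
// (GemmM, GemmN) mapping as E.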
+ static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + // If we are using multiAB and one of the template datatype parameters is not a tuple, convert + // it to it + using GemmADataType = std::conditional_t, ADataType>; + using GemmBDataType = std::conditional_t, BDataType>; + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + // Use appropriate gridwise gemm + using GridwiseGemm = + std::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; + + // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. + using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not + // in initializer list what is required for single const pointer). 
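// In other words: whichever of A / B is a tuple type is passed as a ck::Array of const void* and
// unpacked element by element in the Argument constructor below, while a non-tuple side (and the
// plain single-A/single-B case) arrives as one const void* and only slot I0 of its grid-pointer
// tuple is populated.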
+ using AGridPointer = remove_cvref_t< + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + using BGridPointer = remove_cvref_t< + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument + { + __device__ __host__ Argument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + if constexpr(isMultiA || isMultiB) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + + // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data + // type is not tuple) + using DataType = remove_cvref_t>; + // 
It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiA) + { + // p_as is tuple + p_as_grid_(i) = static_cast(p_as[i.value]); + } + else + { + // if MultiB and not MultiA then p_as is single pointer + p_as_grid_(i) = static_cast(p_as); + } + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiB) + { + // p_bs is tuple + p_bs_grid_(i) = static_cast(p_bs[i.value]); + } + else + { + // if MultiA and not MultiB then p_bs is single pointer + p_bs_grid_(i) = static_cast(p_bs); + } + }); + } + else + { + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + + // p_as and p_bs are pointers + p_as_grid_(I0) = static_cast(p_as); + p_bs_grid_(I0) = static_cast(p_bs); + } + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + } + + // private: + // pointers (tuple if multi AB, pointer if no) + AGridPointer p_as_grid_; + BGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + 
AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch + compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + ck::Array a_g_n_c_wis_lengths_; + ck::Array a_g_n_c_wis_strides_; + ck::Array b_g_k_c_xs_lengths_; + ck::Array b_g_k_c_xs_strides_; + ck::Array, NumDTensor> ds_g_n_k_wos_lengths_; + ck::Array, NumDTensor> ds_g_n_k_wos_strides_; + ck::Array e_g_n_k_wos_lengths_; + ck::Array e_g_n_k_wos_strides_; + ck::Array conv_filter_strides_; + ck::Array conv_filter_dilations_; + ck::Array input_left_pads_; + ck::Array input_right_pads_; + }; + + static __device__ __host__ bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock 
== 0)) + { + valid = false; + } + + if constexpr(is_same_v) + { + // G and K must be the same + if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] || + arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2]) + { + valid = false; + } + } + else + { + // E and D must have the same shape + for(index_t d = 0; d < NDimSpatial + 3; d++) + { + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + valid = false; + } + } + } + } + else + { + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + if constexpr(isMultiA || isMultiB) + { + // Genarate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + + static __device__ __host__ auto MakeArgument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static __device__ __host__ auto MakeArgument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + 
std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7fca3e29886ebeac37842fc9d35eed6fa77e888e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp @@ -0,0 +1,523 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
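// Algorithm note for the pooling backward kernel below: Din is produced by a family of small
// reductions rather than by scattering. The window stride / dilation pair is decomposed into
// YTilde / XTilde congruence classes (the same decomposition used for backward-data convolution);
// each (i_ytilde, i_xtilde) pass reduces the Dout elements that can reach one sub-lattice of Din
// positions, and the accumulated sum is scaled by 1 / (window_lengths[0] * window_lengths[1])
// through the UnaryDivide element-wise operation.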
+ +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// In and Din = [N, C, Hi, Wi] +// Out and Dout = [N, C, Ho, Wo] +// Out = AvgPool2dFwd(In) +// Din = AvgPool2dBwd(Dout) +// Pooling dimension = H, W +template +struct DeviceAvgPool2dBwd_NHWC_NHWC : public DeviceAvgPoolBwd<2, + DOutDataType, + DInDataType, + tensor_layout::convolution::NHWC, + tensor_layout::convolution::NHWC> +{ + + static constexpr ck::index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto + Make2DGridDescriptor_Out_M_K_In_M(const std::vector& dout_n_c_wos_lengths, + const std::vector& din_n_c_wos_length, + const std::vector& dout_n_c_wos_strides, + const std::vector& din_n_c_wos_strides, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads, + const std::vector& tildes) + { + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t N = dout_n_c_wos_lengths[0]; + const index_t C = dout_n_c_wos_lengths[1]; + const index_t Ho = dout_n_c_wos_lengths[2]; + const index_t Wo = dout_n_c_wos_lengths[3]; + + const index_t Hi = din_n_c_wos_length[2]; + const index_t Wi = din_n_c_wos_length[3]; + + const index_t Y = window_lengths[0]; + const index_t X = window_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = window_strides[0]; + const index_t ConvStrideW = window_strides[1]; + + const index_t ConvDilationH = window_dilations[0]; + const index_t ConvDilationW = window_dilations[1]; + + const index_t Ni_stride = dout_n_c_wos_strides[0]; + const index_t Ci_stride = dout_n_c_wos_strides[1]; + const index_t Ho_stride = dout_n_c_wos_strides[2]; + const index_t Wo_stride = dout_n_c_wos_strides[3]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on Tildes that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = 
math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // ReduceK is different for each Reduce + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // Problem size of reduction kernel + const index_t MRaw = N * HTildeSlice * WTildeSlice * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = YDotSlice * XDotSlice; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + const auto out_n_ho_wo_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Ho, Wo, C), make_tuple(Ni_stride, Ho_stride, Wo_stride, Ci_stride)); + + // Out[ReduceM, ReduceK] + const auto out_n_hop_wop_c_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_c_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C)), + make_merge_transform(make_tuple(YDotSlice, XDotSlice))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor( + out_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // In[ReduceM] + const auto in_n_hi_wi_c_grid_desc = + 
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(din_n_c_wos_strides[0], + din_n_c_wos_strides[2], + din_n_c_wos_strides[3], + din_n_c_wos_strides[1])); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_grid_desc_reducemraw = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto in_grid_desc_reducem = + transform_tensor_descriptor(in_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem); + } + + using DoutDinGridDesc = decltype(Make2DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0}, + {0, 0}, + {0, 0}, + {0, 0}, + {0, 0}, + {0, 0})); + + using DoutGridDesc_M_K = remove_cvref_t>; + using DinGridDesc_M = remove_cvref_t>; + + // FIXME + // for NHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. 
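// Because ReduceM is a merge of (N, HTildeSlice, WTildeSlice, C) with C innermost, vectorized
// access along M walks consecutive C elements of both Dout and Din, which is why the vector
// dimension below is fixed to M (0) rather than to the reduced K dimension.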
+ static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K + + using PassThrough = tensor_operation::element_wise::PassThrough; + using Div = tensor_operation::element_wise::UnaryDivide; + + using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + DInDataType* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : p_dout_grid_{p_dout}, + p_din_grid_{p_din}, + dout_n_c_wos_lengths_{dout_n_c_wos_lengths}, + din_n_c_wos_length_{din_n_c_wos_length}, + dout_n_c_wos_strides_{dout_n_c_wos_strides}, + din_n_c_wos_strides_{din_n_c_wos_strides}, + num_reduce_{1}, + div_element_op_{window_lengths[0] * window_lengths[1]} + { + std::vector Tildes(NDimSpatial); + for(int i = 0; i < NDimSpatial; ++i) + { + int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]); + Tildes[i] = window_strides[i] / GcdStrideDilation; + num_reduce_ *= Tildes[i]; + } + + for(index_t i_ytilde = 0; i_ytilde < Tildes[0]; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < Tildes[1]; ++i_xtilde) + { + const auto YDotSlice = + math::integer_divide_ceil(window_lengths[0] - i_ytilde, Tildes[0]); + const auto XDotSlice = + math::integer_divide_ceil(window_lengths[1] - i_xtilde, Tildes[1]); + + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto dout_din_grid_desc = + Make2DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + + dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]); + din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]); + } + } + } + + const DOutDataType* p_dout_grid_; + DInDataType* p_din_grid_; + std::vector dout_n_c_wos_lengths_; + std::vector din_n_c_wos_length_; + std::vector dout_n_c_wos_strides_; + std::vector din_n_c_wos_strides_; + + int num_reduce_; + std::vector dout_grid_desc_m_k_container_; + std::vector din_grid_desc_m_container_; + + Div div_element_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + for(index_t i = 0; i < arg.num_reduce_; i++) + { + const auto kernel = kernel_reduce_threadwise; + + ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0); + const index_t grid_size = (M / M_BlockTileSize); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.dout_grid_desc_m_k_container_[i], + arg.din_grid_desc_m_container_[i], + PassThrough{}, + arg.div_element_op_, + float(1), + arg.p_dout_grid_, + nullptr, + float(0), + arg.p_din_grid_, + nullptr); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + constexpr index_t Rank = NDimSpatial + 2; + int doutFastestDim = -1; + int dinFastestDim = -1; + + for(int i = 0; i < Rank; ++i) + { + if(arg.dout_n_c_wos_strides_[i] == 1) + doutFastestDim = i; + if(arg.din_n_c_wos_strides_[i] == 1) + 
dinFastestDim = i; + } + if(InSrcOutDstVectorSize != 1 && (dinFastestDim != 1 || doutFastestDim != 1)) + { + return false; + } + if(doutFastestDim == -1 || dinFastestDim == -1) + { + if constexpr(InSrcOutDstVectorSize != 1) + return false; + } + else + { + if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0) + return false; + if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0) + return false; + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) override + { + constexpr index_t Rank = NDimSpatial + 2; + + if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank || + dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank) + { + throw std::runtime_error("dimension of [dout|din]_n_c_wos_strides or " + "[dout|din]_n_c_wos_lengths is not equal to Rank"); + } + + if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial || + window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial || + input_right_pads.size() != NDimSpatial) + { + throw std::runtime_error( + "dimension of [window_lengths, window_strides, window_dilations, input_left_pads, " + "input_right_pads] is not equal to Rank"); + } + return std::make_unique(static_cast(p_dout), + static_cast(p_din), + dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceAvgPool2dBwd<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index 4d599e8017991b4f1b1905de9acd23a219ac67f9..ab3f3856aaab2eb54e7759a05be0ba82713822dc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -62,10 +62,10 @@ template struct DeviceBatchedContractionMultipleD_Wmma_CShuffle @@ -123,15 +123,37 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 = Max Vector Access Pixels static constexpr auto K1Number = 
Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] - static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, - const std::vector& a_gs_ms_ks_strides_vec) + static auto MakeAGridDescriptor(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) { assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); @@ -158,36 +180,72 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // lengths for K0, K1, ... const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); - if constexpr(ASpec == TensorSpecialization::Packed) + const auto a_grid_desc_m_k = [&]() { + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] 
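// Aside: a compact restatement (hypothetical helper, not part of the patch) of
// the AEnableLds/BEnableLds heuristic introduced above: LDS for A can only be
// bypassed when NWaves == 1 and either the A read is already a full 16-byte
// vector (K1 * sizeof(ADataType) == 16) or MRepeat == 1; multi-stage prefetch
// (NumPrefetch > 1) always forces the LDS path. B is symmetric.
namespace sketch {

struct LdsDecision
{
    bool a_enable_lds;
    bool b_enable_lds;
};

constexpr LdsDecision decide_lds(int m_per_block, int n_per_block,
                                 int m_repeat, int n_repeat,
                                 int m_per_wmma, int n_per_wmma,
                                 int k1, int a_elem_bytes, int b_elem_bytes,
                                 int num_prefetch)
{
    const int m_waves = m_per_block / (m_repeat * m_per_wmma);
    const int n_waves = n_per_block / (n_repeat * n_per_wmma);

    const bool max_vec_a = (k1 * a_elem_bytes == 16);
    const bool max_vec_b = (k1 * b_elem_bytes == 16);

    const bool a_auto = !(n_waves == 1 && (max_vec_a || m_repeat == 1));
    const bool b_auto = !(m_waves == 1 && (max_vec_b || n_repeat == 1));

    return {a_auto || num_prefetch > 1, b_auto || num_prefetch > 1};
}

// e.g. a 128x64 block tile with MRepeat = 4, NRepeat = 1, 16x16 WMMA and fp16
// (K1 = 8, 2-byte elements): NWaves = 64 / (1 * 16) = 4, so A stays in LDS.
static_assert(decide_lds(128, 64, 4, 1, 16, 16, 8, 2, 2, 1).a_enable_lds, "");

} // namespace sketch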
+ const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) { - auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); - auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); - const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( - make_tuple(M, K), - make_tuple(a_ms_ks_strides[Number{}], - a_ms_ks_strides[Number{}])); - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } else { - // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] - const auto a_grid_desc_ms_ks = - make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); - - // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] - const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( - a_grid_desc_ms_ks, - make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), - make_tuple(mDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); } } // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] - static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, - const std::vector& b_gs_ns_ks_strides_vec) + static auto MakeBGridDescriptor(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) { assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); @@ -214,30 +272,66 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // lengths for N0, N1, ... const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); - if constexpr(BSpec == TensorSpecialization::Packed) + const auto b_grid_desc_n_k = [&]() { + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] 
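// Aside: the K-dimension factorization used by the direct-to-register (no-LDS)
// A path above, written out as plain integers. Standalone, hypothetical names;
// the device op derives the same quantities from K1 and WmmaK, and the visible
// code fixes A_KRow to 2.
namespace sketch {

struct AKSplit
{
    int k_wmma;       // number of WMMA iterations along K
    int k0_per_wmma;  // K0 slots inside one WMMA instruction per row group
    int k_row;        // row-group factor (2 in the code above)
    int k1;           // innermost vector access width
};

constexpr AKSplit split_k_for_wmma(int K, int K1)
{
    const int wmma_k      = (K1 == 16) ? 32 : 16;
    const int k_row       = 2;
    const int k0_per_wmma = wmma_k / k_row / K1;
    return {K / wmma_k, k0_per_wmma, k_row, K1};
}

// Example: K = 256, K1 = 8 -> WmmaK = 16, A_K0PerWmma = 1, A_KWmma = 16, and
// the unmerged lengths multiply back to 16 * 1 * 2 * 8 = 256 = K.
constexpr AKSplit example = split_k_for_wmma(256, 8);
static_assert(example.k_wmma * example.k0_per_wmma * example.k_row * example.k1 == 256,
              "K must factor exactly into (KWmma, K0PerWmma, KRow, K1)");

} // namespace sketch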
+ const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) { - auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); - auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); - const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( - make_tuple(N, K), - make_tuple(b_ns_ks_strides[Number{}], - b_ns_ks_strides[Number{}])); - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } else { - // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] - const auto b_grid_desc_ns_ks = - make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); - - // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] - const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( - b_grid_desc_ns_ks, - make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), - make_tuple(nDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); } } @@ -393,8 +487,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle } // Gridwise descriptor, mapping to whole given provblem. 
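// Aside: what the TensorSpecialization::Packed branches above (for A and B)
// boil down to: the M/N and K groups of dimensions are collapsed into single
// lengths by a product reduction, and the matrix padder then rounds them up to
// the block tile. Standalone sketch with hypothetical names.
#include <initializer_list>

namespace sketch {

constexpr long prod(std::initializer_list<long> v)
{
    long p = 1;
    for(long x : v)
        p *= x;
    return p;
}

constexpr long pad_to(long len, long tile) { return ((len + tile - 1) / tile) * tile; }

// e.g. B[N0=12, N1=8][K0=4, K1=64] with NPerBlock = 128, KPerBlock = 64:
// NRaw = 96 is padded up to 128, KRaw = 256 is already a multiple of 64.
static_assert(pad_to(prod({12, 8}), 128) == 128, "");
static_assert(pad_to(prod({4, 64}), 64) == 256, "");

} // namespace sketch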
- using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); - using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); @@ -449,45 +541,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle EGridDesc_G_M_N e_grid_desc_g_m_n_; }; - // A desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k) - { - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK0 = K / K1; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - // B desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k) - { - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK0 = K / K1; - - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{})); - using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{})); + using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {})); + using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {})); // GridwiseOp - using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseOp = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -496,8 +554,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle DsDataType, EDataType, // InMemory Data Descriptor - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, + AGridDesc, + BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, // ElementwiseOp Family @@ -508,9 +566,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, - MPerWMMA, - NPerWMMA, + KPerBlock, + MPerWmma, + NPerWmma, K1, MRepeat, NRepeat, @@ -523,6 +581,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, @@ -531,6 +590,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, @@ -564,16 +624,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle p_b_grid_{static_cast(p_b_grid)}, p_ds_grid_{}, p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_m_k_{}, - b_grid_desc_n_k_{}, + a_grid_desc_{}, + b_grid_desc_{}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{}, ds_grid_desc_g_m_n_{ DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, e_grid_desc_g_m_n_{ DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, 
ds_grid_desc_mblock_mperblock_nblock_nperblock{}, e_grid_desc_mblock_mperblock_nblock_nperblock{}, block_2_ctile_map_{}, @@ -600,10 +658,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle p_ds_grid_(i) = static_cast(p_ds_grid[i]); }); - a_grid_desc_m_k_ = - DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - b_grid_desc_n_k_ = - DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); @@ -611,9 +667,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_); - b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_); - block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); ds_grid_desc_mblock_mperblock_nblock_nperblock = @@ -644,16 +697,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle EDataType* p_e_grid_; // Tensor Descriptors - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; DsGridDesc_G_M_N ds_grid_desc_g_m_n_; EGridDesc_G_M_N e_grid_desc_g_m_n_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -686,6 +736,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // Batch Offset ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // for checking vector load/store + // index_t MRaw_; + // index_t NRaw_; + // index_t KRaw_; }; // Invoker @@ -700,8 +755,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; @@ -712,8 +776,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle BDataType, typename GridwiseOp::DsGridPointer, EDataType, - DeviceOp::AGridDesc_K0_M_K1, - DeviceOp::BGridDesc_K0_N_K1, + DeviceOp::AGridDesc, + DeviceOp::BGridDesc, typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, AElementwiseOperation, @@ -733,8 +797,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle arg.p_ds_grid_, arg.p_e_grid_, G, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, + arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, arg.e_grid_desc_mblock_mperblock_nblock_nperblock, arg.a_element_op_, @@ -770,11 +834,11 @@ 
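// Aside: the two K recoveries used in Invoker::Run above are consistent by
// construction. With LDS enabled the A descriptor is (K0, M, K1), so
// K = K0 * K1; without LDS it is (KWmma, ..., K0PerWmma, KRow, ..., K1), so
// K = KWmma * K0PerWmma * KRow * K1. Hypothetical standalone check:
namespace sketch {

constexpr int k_from_lds_desc(int k0, int k1) { return k0 * k1; }

constexpr int k_from_direct_desc(int k_wmma, int k0_per_wmma, int k_row, int k1)
{
    return k_wmma * k0_per_wmma * k_row * k1;
}

// K = 256, K1 = 8: the LDS view (K0 = 32, K1 = 8) and the direct view
// (KWmma = 16, K0PerWmma = 1, KRow = 2, K1 = 8) describe the same K.
static_assert(k_from_lds_desc(32, 8) == k_from_direct_desc(16, 1, 2, 8), "");

} // namespace sketch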
struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102") + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { + printf("DeviceOp: Arch check failure\n"); return false; } } @@ -783,12 +847,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle return false; } - if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, + if(!GridwiseOp::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_, arg.block_2_ctile_map_)) { + printf("GridwiseOp: Validity check failure\n"); return false; } @@ -801,17 +866,23 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle if constexpr(ABlockTransferSrcVectorDim == 1) { if(!(arg.a_mz_stride_ == 1 && - arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + arg.a_grid_desc_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access A-m check failure\n"); return false; } } else { - if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + if(!(arg.a_kz_stride_ == 1)) { - return false; + index_t LastK = + AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6); + if(LastK % ABlockTransferSrcScalarPerVector == 0) + { + printf("DeviceOp: Vector Access A-k check failure\n"); + return false; + } } } @@ -819,16 +890,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle if constexpr(BBlockTransferSrcVectorDim == 1) { if(!(arg.b_nz_stride_ == 1 && - arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access B-n check failure\n"); return false; } } else { if(!(arg.b_kz_stride_ == 1 && - arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access B-k check failure\n"); return false; } } @@ -842,6 +915,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) { + printf("DeviceOp: Vector Access D-n check failure\n"); valid_d_access = false; } }); @@ -858,6 +932,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle 0) || CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1)) { + printf("DeviceOp: Vector Access E-n check failure\n"); return false; } @@ -968,14 +1043,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock << ", " + << KPerBlock << ", " << K1 << ", " - << MPerWMMA << ", " - << NPerWMMA << ", " + << MPerWmma << ", " + << NPerWmma << ", " << MRepeat << ", " << NRepeat << ">" - << " NumPrefetch: " + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " << NumPrefetch << ", " << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index fe971171811b8a1ecba4b53ed509a9b546d20ffb..64aa398d531e634c54c3d55d591789f72943bd0e 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -57,7 +57,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = @@ -543,9 +543,13 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle EGridDesc_G_M_N e_grid_desc_g_m_n_; }; + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index d6b6405bb70af4699178b97f6a5b5052f136f4a5..d06eab1264684b595dc87dc1cf91a9f2b44a5056 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -75,7 +75,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -331,8 +331,13 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute, // DsDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index 3dbe8c67226ec4e4f0dc5bf4ac02ac49ae73cf46..e950169ccfe58887351c4bfbc82af85a061287cf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -61,7 +61,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index 3df2ee38f2915b1112186e2a3f1fbfd64a3b174a..d6b92bc97a8d3eed0d0ce50ad26a9169c498671a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -84,7 +84,7 @@ __global__ void { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = 
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -324,8 +324,12 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD{}([&](auto i) { using D0Layout = remove_cvref_t>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..314ecdf76e59da5ccc9e59a67691968426fa4b1f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,1014 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); + + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_ds_grid, + karg.p_c_grid + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char 
p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_ds_grid, + karg.p_c_grid + c_batch_offset, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 + : public DeviceBatchedGemmV2MultiD +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideA_) * g_idx; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideB_) * g_idx; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset_; + + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + ds_offset_[i] = static_cast(BatchStrideDs_[i]) * g_idx; + }); + + return ds_offset_; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideC_) * g_idx; + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + const std::array BatchStrideDs_; + index_t 
BatchStrideC_; + }; + + struct Argument : public GridwiseGemm::Argument + { + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_, + index_t BatchStrideA_, + index_t BatchStrideB_, + const std::array& BatchStrideDs_, + index_t BatchStrideE_, + index_t Batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument{p_a_grid_, + p_b_grid_, + p_ds_grid_, + p_e_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideDs_, + StrideE_, + 1, + a_element_op_, + b_element_op_, + c_element_op_}, + Batch{Batch_}, + compute_ptr_offset_of_batch{ + BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_} + { + } + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) * arg.Batch; + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) * arg.Batch; + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + 
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + 
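// Aside: this if/else ladder (and its Set counterpart that follows) enumerates
// the cross-product of (KBatch > 1 ? AtomicAdd : Set) with the tail numbers
// One..Seven, gated by the pipeline's PrefetchStages. A compact restatement of
// that selection, with hypothetical names; the real code instantiates a
// distinct kernel template for each pair instead of returning a value.
namespace sketch {

enum class MemOp { Set, AtomicAdd };
enum class Tail  { One = 1, Two, Three, Four, Five, Six, Seven, Full };

struct KernelChoice
{
    MemOp c_write_op;
    Tail  tail;
    bool  instantiated; // false if the pipeline has too few prefetch stages
};

constexpr KernelChoice choose_v2_kernel(int k_batch, Tail tail, int prefetch_stages)
{
    const MemOp op = (k_batch > 1) ? MemOp::AtomicAdd : MemOp::Set;
    const bool ok  = (tail == Tail::Full) || (tail == Tail::One) ||
                     (static_cast<int>(tail) < prefetch_stages);
    return {op, tail, ok};
}

// e.g. a 3-stage prefetch pipeline with KBatch == 2 and a two-iteration tail
// runs the AtomicAdd kernel specialized for TailNumber::Two.
static_assert(choose_v2_kernel(2, Tail::Two, 3).c_write_op == MemOp::AtomicAdd, "");
static_assert(choose_v2_kernel(2, Tail::Two, 3).instantiated, "");

} // namespace sketch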
} + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, 
+ Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + 
BatchStrideE, + Batch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + Batch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceBatchedGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? 
std::array{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::array{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::array{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::array{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::array{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + 
static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = alpha; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Self-Attention +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors +// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize] + constexpr index_t array_size = 4; + std::array qk_gs_ms_ks_lengths{batch_size, head_count, sequence_length, head_size}; + std::array qk_gs_ms_ks_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, head_count * 3 * head_size, 1}; + + std::array v_gs_os_ns_lengths{batch_size, head_count, head_size, sequence_length}; + std::array v_gs_os_ns_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, 1, head_count * 3 * head_size}; + + std::array c_gs_ms_os_lengths{batch_size, head_count, sequence_length, head_size}; + std::array c_gs_ms_os_strides{sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename 
DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + const index_t qkv_gap = __builtin_amdgcn_readfirstlane(head_size); +#ifdef CK_SELF_ATTN_DEBUG + if(get_thread_global_1d_id() == 0) + { + printf("batch_size: %d\n", batch_size); + printf("sequence_length: %d\n", sequence_length); + printf("head_count: %d\n", head_count); + printf("head_size: %d\n", head_size); + printf("qkv_gap: %d\n", qkv_gap); + printf("get_grid_size(): %d\n", get_grid_size()); + printf("batch_count: %d\n", batch_count); + printf("blockid: %d\n", get_block_1d_id()); + printf("num_blocks_per_batch: %d\n", num_blocks_per_batch); + printf("g_idx: %d\n", g_idx); + printf("a_batch_offset: %ld\n", a_batch_offset); + printf("b0_batch_offset: %ld\n", b0_batch_offset); + printf("b1_batch_offset: %ld\n", b1_batch_offset); + } +#endif + GridwiseOp::template Run(p_qkv_grid + 0 * qkv_gap + a_batch_offset, + p_qkv_grid + 1 * qkv_gap + b0_batch_offset, + p_qkv_grid + 2 * qkv_gap + b1_batch_offset, + p_out_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_qkv_grid; + ignore = p_out_grid; + ignore = batch_size; + ignore = sequence_length; + ignore = head_count; + ignore = head_size; + ignore = alpha; +#endif // end of if (defined(__gfx11__)) +} +// Cross-Attention +// Self-Attention +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid, + const KVDataType* __restrict__ p_kv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors +// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize] + constexpr index_t array_size = 4; + std::array q_gs_ms_ks_lengths{batch_size, head_count, q_sequence_length, head_size}; + std::array q_gs_ms_ks_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + std::array k_gs_ms_ks_lengths{batch_size, head_count, 
kv_sequence_length, head_size}; + std::array k_gs_ms_ks_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, head_count * 2 * head_size, 1}; + + std::array v_gs_os_ns_lengths{batch_size, head_count, head_size, kv_sequence_length}; + std::array v_gs_os_ns_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, 1, head_count * 2 * head_size}; + + std::array c_gs_ms_os_lengths{batch_size, head_count, q_sequence_length, head_size}; + std::array c_gs_ms_os_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(k_gs_ms_ks_lengths, k_gs_ms_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(k_gs_ms_ks_lengths, k_gs_ms_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + const index_t kv_gap = __builtin_amdgcn_readfirstlane(head_size); +#ifdef CK_SELF_ATTN_DEBUG + if(get_thread_global_1d_id() == 0) + { + printf("batch_size: %d\n", batch_size); + printf("q_sequence_length: %d\n", q_sequence_length); + printf("k_sequence_length: %d\n", kv_sequence_length); + printf("head_count: %d\n", head_count); + 
printf("head_size: %d\n", head_size); + printf("kv_gap: %d\n", kv_gap); + printf("get_grid_size(): %d\n", get_grid_size()); + printf("batch_count: %d\n", batch_count); + printf("blockid: %d\n", get_block_1d_id()); + printf("num_blocks_per_batch: %d\n", num_blocks_per_batch); + printf("g_idx: %d\n", g_idx); + printf("a_batch_offset: %ld\n", a_batch_offset); + printf("b0_batch_offset: %ld\n", b0_batch_offset); + printf("b1_batch_offset: %ld\n", b1_batch_offset); + } +#endif + GridwiseOp::template Run(p_q_grid + a_batch_offset, + p_kv_grid + 0 * kv_gap + b0_batch_offset, + p_kv_grid + 1 * kv_gap + b1_batch_offset, + p_out_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_q_grid; + ignore = p_kv_grid; + ignore = p_out_grid; + ignore = batch_size; + ignore = q_sequence_length; + ignore = kv_sequence_length; + ignore = head_count; + ignore = head_size; + ignore = alpha; +#endif // end of if (defined(__gfx11__)) +} +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? 
false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const 
B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + 
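+        // Illustrative note: in the attention use-case exercised by the
+        // self-/cross-attention launchers above, the A/B0/B1/C pointers below
+        // correspond to the Q/K/V/output tensors, M and N are the query and
+        // key/value sequence lengths, K and O are the head sizes of the
+        // Q*K^T and softmax(..)*V GEMMs, and G0/G1 are the batch and head
+        // counts.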
const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? 
std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + struct SelfAttnArg : public BaseArgument + { + SelfAttnArg(const ADataType* p_qkv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) + : p_qkv_grid_{p_qkv_grid}, + p_out_grid_{p_out_grid}, + batch_size_{batch_size}, + sequence_length_{sequence_length}, + head_count_{head_count}, + head_size_{head_size}, + alpha_{alpha} + { + } + // Pointers + const ADataType* p_qkv_grid_; + CDataType* p_out_grid_; + + // Raw Problem Size + index_t batch_size_; + index_t sequence_length_; + index_t head_count_; + index_t head_size_; + float alpha_; + }; + + static auto MakeSelfAttnArgument(const ADataType* p_qkv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) + { + return SelfAttnArg{ + p_qkv_grid, p_out_grid, batch_size, sequence_length, head_count, head_size, alpha}; + } + + struct CrossAttnArg : public BaseArgument + { + CrossAttnArg(const ADataType* p_q_grid, + const B0DataType* p_kv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) + : p_q_grid_{p_q_grid}, + p_kv_grid_{p_kv_grid}, + p_out_grid_{p_out_grid}, + batch_size_{batch_size}, + q_sequence_length_{q_sequence_length}, + kv_sequence_length_{kv_sequence_length}, + head_count_{head_count}, + head_size_{head_size}, + alpha_{alpha} + { + } + // Pointers + const ADataType* p_q_grid_; + const B0DataType* p_kv_grid_; + CDataType* p_out_grid_; + + // Raw Problem Size + index_t batch_size_; + index_t q_sequence_length_; + index_t kv_sequence_length_; + index_t head_count_; + index_t head_size_; + float alpha_; + }; + + static auto MakeCrossAttnArgument(const ADataType* p_q_grid, + const B0DataType* p_kv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) + { + return CrossAttnArg{p_q_grid, + p_kv_grid, + p_out_grid, + batch_size, + q_sequence_length, + kv_sequence_length, + head_count, + head_size, + alpha}; + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + 
B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + 
AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct SelfAttnInvoker : public BaseInvoker + { + using Argument = DeviceOp::SelfAttnArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.sequence_length_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock); + + const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0; + const auto K = arg.head_size_; + + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_wmma_self_attention_forward; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_qkv_grid_, + arg.p_out_grid_, + arg.batch_size_, + arg.sequence_length_, + arg.head_count_, + arg.head_size_, + arg.alpha_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto MakeSelfAttnInvoker() { return SelfAttnInvoker{}; } + + // Invoker + struct CrossAttnInvoker : public BaseInvoker + { + using Argument = DeviceOp::CrossAttnArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.q_sequence_length_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock); + + const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0; + const auto K = arg.head_size_; + + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_wmma_cross_attention_forward; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_q_grid_, + arg.p_kv_grid_, + arg.p_out_grid_, + arg.batch_size_, + arg.q_sequence_length_, + arg.kv_sequence_length_, + arg.head_count_, + arg.head_size_, + arg.alpha_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto MakeCrossAttnInvoker() { return CrossAttnInvoker{}; } + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto 
launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = + kernel_batched_gemm_softmax_gemm_wmma_cshuffle; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + 
b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index c38b7e1c5bc322a1765fdff02c0f26db205bb5e2..e178b8f5252781ead149f5d2b78f8fc53125a3af 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -68,7 +68,7 @@ __global__ void const C0MatrixMask c0_matrix_mask) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = 
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -719,9 +719,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DEBUG_LOG - arg.Print(); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + arg.Print(); + } if(!ck::is_xdl_supported()) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 12ed84d63c0d52b9fe705f5811d89fa5ebc08018..12fb7775ad5a53c285e695b4fed8086d22bdc63e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -65,7 +65,7 @@ __global__ void const C0MatrixMask c0_matrix_mask) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index 408c56d490e52aa30742cf29ed0b7df52c718c76..6be2ffbdd781d27cc43845b7da311bc8a929e42b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -53,7 +53,7 @@ __global__ void kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -185,7 +185,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm" << " NumGemmKPrefetchStage: " << NumGemmKPrefetchStage << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp index f46237e005642d433c29f5ab0f027a224043dbef..3b62cf10a361144693255049520e891add60635a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp @@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp index ad8e7956033db31db2992d6fa9f86cf16a29b150..e7e4668d92e72d985168462358eaa088b5dc385f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp @@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp index 
b826793c27be946a99d0d96a99d745fd6c618822..c3e0837722cfd3f2cec9f2c20a19165943a9ef6d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp @@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp index a095521161c76d7bc831c76987a07691c415895c..8843e520a606037b9aa0e9bd59f90274cb4c24ec 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ #include "ck/tensor_operation/gpu/device/device_cgemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -80,42 +80,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - static constexpr auto MPerThread = Number<4>{}; + static constexpr index_t MPerThread = + MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t NPerThread = + NPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + static constexpr auto AScalarPerVector = Number<4>{}; static constexpr auto BScalarPerVector = Number<4>{}; static constexpr auto CScalarPerVector = Number<4>{}; - template - static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + template + static auto PadDescriptor_M_N(Desc_M_N desc) { - const auto M = desc_m.GetLength(I0); - const index_t loop_step = gridSize * blockSize * MPerThread; - const auto pad = math::integer_least_multiple(M, loop_step) - M; - const auto desc_m_pad = - transform_tensor_descriptor(desc_m, - make_tuple(make_right_pad_transform(M, pad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - return desc_m_pad; + const auto M = desc.GetLength(I0); + const auto N = desc.GetLength(I1); + const auto pad_M = math::integer_divide_ceil(M, MPerThread) * MPerThread - M; + const auto pad_N = math::integer_divide_ceil(N, NPerThread) * NPerThread - N; + + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_right_pad_transform(M, pad_M), make_right_pad_transform(N, pad_N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return padded_desc; } - static auto MakeDescriptor_M(const std::vector& lengths, - const std::vector& strides, - index_t gridSize, - index_t blockSize) + static auto MakeDescriptor_M_N(const std::vector& lengths, + const std::vector& strides) { auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{}); auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; 
}, Number<2>{}); // nd desc - [s0, s1, s2, ...] - const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); - const auto desc_m = transform_tensor_descriptor( - desc, - make_tuple(make_merge_transform(tupleOfShape)), - make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})), - make_tuple(Sequence<0>{})); - - return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + return PadDescriptor_M_N(desc); } // GridwiseGemm @@ -166,7 +165,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1})); // Argument struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem @@ -195,17 +194,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle p_c_grid_imag{p_c_grid_imag_}, p_aux_grid{p_workspace} { - const index_t grid_size = std::get<1>(GridwiseGemm::CalculateGridSize(M_, N_)); - if constexpr(is_same::value) { - c_grid_desc_m = - DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize); + c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {StrideC_, I1}); } else if constexpr(is_same::value) { - c_grid_desc_m = - DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize); + c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {I1, StrideC_}); } p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_); @@ -220,7 +215,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CDataType* p_c_grid_imag; CDataType* p_aux_grid; CDataType* p_aux_2_grid; - CGridDesc_M c_grid_desc_m; + CGridDesc_M_N c_grid_desc_m_n; }; // Invoker @@ -248,40 +243,63 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle using Add = ck::tensor_operation::element_wise::Add; using Subtract = ck::tensor_operation::element_wise::Subtract; - using GridwiseBinAdd = - GridwiseElementwise_1D, - Tuple, - Tuple, - Tuple, - Add, - MPerThread, - Sequence, - Sequence>; + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseBinAdd = GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + Add, + BlockSize, + MPerBlock, + NPerBlock, + MPerThread, + NPerThread, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; using GridwiseBinSubtract = - GridwiseElementwise_1D, - Tuple, - Tuple, - Tuple, - Subtract, - MPerThread, - Sequence, - Sequence>; - - const auto add_kernel = kernel_elementwise_1d, - Tuple, - Tuple, - Tuple, - Add>; + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + Subtract, + BlockSize, + MPerBlock, + NPerBlock, + MPerThread, + NPerThread, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + + const index_t M = arg.c_grid_desc_m_n.GetLength(I0); + const index_t N = arg.c_grid_desc_m_n.GetLength(I1); + const auto block_2_tile_map = Block2TileMap(M, N); + + const auto add_kernel = kernel_elementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + Add>; const auto subtract_kernel = - kernel_elementwise_1d, - Tuple, - Tuple, - Tuple, - Subtract>; + kernel_elementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + Subtract>; if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { @@ -318,11 +336,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), - make_tuple(arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n), + make_tuple(arg.c_grid_desc_m_n), 
make_tuple(const_cast(arg.p_aux_grid), const_cast(arg.p_aux_2_grid)), make_tuple(arg.p_c_grid_real), + block_2_tile_map, Subtract{}); ave_time += launch_and_time_kernel(stream_config, @@ -352,11 +371,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), - make_tuple(arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n), + make_tuple(arg.c_grid_desc_m_n), make_tuple(const_cast(arg.p_aux_grid), const_cast(arg.p_aux_2_grid)), make_tuple(arg.p_c_grid_imag), + block_2_tile_map, Add{}); } else @@ -394,11 +414,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), - make_tuple(arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n), + make_tuple(arg.c_grid_desc_m_n), make_tuple(const_cast(arg.p_aux_grid), const_cast(arg.p_aux_2_grid)), make_tuple(arg.p_c_grid_real), + block_2_tile_map, Subtract{}); ave_time += launch_and_time_kernel(stream_config, @@ -428,11 +449,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), - make_tuple(arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n), + make_tuple(arg.c_grid_desc_m_n), make_tuple(const_cast(arg.p_aux_grid), const_cast(arg.p_aux_2_grid)), make_tuple(arg.p_c_grid_imag), + block_2_tile_map, Add{}); } @@ -576,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle [[maybe_unused]] index_t K, [[maybe_unused]] index_t StrideA, [[maybe_unused]] index_t StrideB, - index_t StrideC) override + index_t StrideC) const override { return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC); } + + std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override + { + const auto* parg = dynamic_cast(base_arg); + + if(!parg) + { + std::ostringstream err; + err << "Provided argument pointer is not of an Argument class!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return GetWorkspaceSize( + parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC); + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e4203e0313a0ed9213841e2474a603a889b039a5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -0,0 +1,646 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
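The GetWorkSpaceSize(const BaseArgument*) override added above lets callers size the auxiliary buffers from a type-erased argument; it dynamic_casts to the concrete Argument and forwards to the explicit-size GetWorkspaceSize, which returns 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC). A minimal, hypothetical call-site sketch (the operator/argument objects and the error handling are assumptions, not part of this patch):

    // cgemm_op: a concrete DeviceCGemm_4Gemm_Xdl_CShuffle instance (assumed)
    // arg_ptr : the BaseArgument created for the same problem sizes (assumed)
    const std::size_t ws_bytes = cgemm_op.GetWorkSpaceSize(arg_ptr.get()); // throws on argument type mismatch
    void* p_workspace = nullptr;
    if(hipMalloc(&p_workspace, ws_bytes) != hipSuccess)
    {
        // handle allocation failure
    }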
+ +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Column to Image: +// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C] +// output : input image [G, N, Di, Hi, Wi, C] +// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C] +// output : input image [N, Di, Hi, Wi, G, C] +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct DeviceColumnToImageImpl + : public DeviceConvTensorRearrange +{ + static constexpr bool is_NSpatialGC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + static constexpr bool is_GNSpatialC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = NDimSpatial == 1 ? I0 : Number{}; + static constexpr auto XIdx = Number{}; + + static constexpr auto spatial_offset = Number<3>{}; + + using ConvToGemmFwdTransformer = + TransformConvFwdToGemm; + static constexpr auto matrix_padder = + MatrixPadder{ + MPerBlock, 0 /* NPerBlock*/, KPerBlock}; + + // Calculate number of independent filters for given conv params + static index_t GetNumberOfIndependentFilters(const index_t input_spatial_len, + const index_t left_pad, + const index_t right_pad, + const index_t filter_len, + const index_t filter_stride, + const index_t filter_dilation, + const index_t image_offset) + { + const index_t x_eff = (filter_len - 1) * filter_dilation + 1; + const index_t next_filter_padded = + math::integer_divide_ceil(x_eff, filter_stride) * filter_stride; + // If filter_stride >= x_eff then each filter is independent + const index_t independent_filter_stride = + filter_stride >= x_eff ? 
filter_stride : next_filter_padded; + const index_t w_eff = input_spatial_len - image_offset + left_pad + right_pad - x_eff; + // There are no independent filters + if(w_eff < 0) + return 0; + const index_t independent_kernels_num = w_eff / independent_filter_stride + 1; + return independent_kernels_num; + } + + // Make column form descriptor + static auto + MakeInputDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& conv_filter_strides, + const std::array& gemm_g_m_k_strides, + const std::array& independent_filters, + const std::array& effs) + { + const index_t DoHoWo = ck::accumulate_n( + output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + + const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2]; + // Calculate the appropriate stride for each set of independent filters + // in each dimension + const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * + gemm_g_m_k_strides[I1]; + const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) * + output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1]; + const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) * + output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] * + gemm_g_m_k_strides[I1]; + // Create descriptor for independent filters in each dimension and + // then merge them into column form + if constexpr(NDimSpatial == 1) + { + const auto desc_gemm_form = + make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX), + make_tuple(NStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + else if constexpr(NDimSpatial == 2) + { + const auto desc_gemm_form = make_naive_tensor_descriptor( + make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX), + make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform( + make_tuple(N, independent_filters[YIdx], independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + else if constexpr(NDimSpatial == 3) + { + const auto desc_gemm_form = make_naive_tensor_descriptor( + make_tuple(N, + independent_filters[ZIdx], + independent_filters[YIdx], + independent_filters[XIdx], + CZYX), + make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform(make_tuple(N, + independent_filters[ZIdx], + independent_filters[YIdx], + independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1, 2, 3>{}, 
Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + } + + // Use MakeADescriptor_M_K from grouped convolution forward + static auto + MakeOutDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const std::array& image_offsets, + const std::array& independent_filters, + const std::array& effs) + { + std::array a_g_n_c_wis_lengths{1}; + std::array b_g_k_c_xs_lengths{1}; + std::array c_g_n_k_wos_lengths{1}; + + auto copy = [](const auto& x, auto& y, index_t dst_offset) { + std::copy(x.begin(), x.end(), y.begin() + dst_offset); + }; + + copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset); + copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset); + // Calculate descriptor only for independent filters + copy(independent_filters, c_g_n_k_wos_lengths, spatial_offset); + + // fill only significant values (C and N) + a_g_n_c_wis_lengths[I1] = N; + a_g_n_c_wis_lengths[I2] = C; + b_g_k_c_xs_lengths[I2] = C; + c_g_n_k_wos_lengths[I1] = N; + + // Modify pads to apply offsets + std::array input_left_pads_with_offset; + for(index_t i = 0; i < NDimSpatial; i++) + { + input_left_pads_with_offset[i] = math::max(0, input_left_pads[i] - image_offsets[i]); + } + // Modify input spatial lengths to apply offsets + for(index_t i = 0; i < NDimSpatial; i++) + { + a_g_n_c_wis_lengths[i + spatial_offset] -= + math::max(0, image_offsets[i] - input_left_pads[i]); + } + + // Strides to next independent filters + std::array independent_filter_strides; + for(index_t i = 0; i < NDimSpatial; i++) + { + index_t independent_filter_stride = + math::integer_divide_ceil(effs[i], conv_filter_strides[i]) * conv_filter_strides[i]; + // If conv stride is greater than whole filter size, use conv stride + independent_filter_strides[i] = conv_filter_strides[i] >= effs[i] + ? 
conv_filter_strides[i] + : independent_filter_stride; + } + + ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + // conv_filter_strides, + independent_filter_strides, + conv_filter_dilations, + input_left_pads_with_offset, + input_right_pads}; + + // Calculate image form descriptor for the modified convolution problem + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + return in_gemmm_gemmk_desc; + } + + using InputGridDesc = + remove_cvref_t; + using OutputGridDesc = remove_cvref_t; + + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + InputGridDesc{}))>; + + using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange>; + + struct Argument : public BaseArgument + { + Argument(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + : G_(G), + C_(C), + X_(filter_spatial_lengths[NDimSpatial - I1]), + p_in_{static_cast(p_in)}, + p_out_{static_cast(p_out)}, + image_g_n_c_wis_strides_{image_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0]; + + const index_t x_eff = + (filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1; + const index_t y_eff = + NDimSpatial < 2 + ? I1 + : (filter_spatial_lengths[YIdx] - 1) * conv_filter_dilations[YIdx] + 1; + const index_t z_eff = + NDimSpatial < 3 + ? 
I1 + : (filter_spatial_lengths[ZIdx] - 1) * conv_filter_dilations[ZIdx] + 1; + + // Iterate over sets of independent filters + for(int z_img_offset = 0; z_img_offset < z_eff; + z_img_offset += conv_filter_strides[ZIdx]) + { + for(int y_img_offset = 0; y_img_offset < y_eff; + y_img_offset += conv_filter_strides[YIdx]) + { + for(int x_img_offset = 0; x_img_offset < x_eff; + x_img_offset += conv_filter_strides[XIdx]) + { + + std::array image_offsets; + std::array effs; + // Calculate the starting offset for a given set of + // independent filters + if constexpr(NDimSpatial == 1) + { + image_offsets = {x_img_offset}; + effs = {x_eff}; + } + if constexpr(NDimSpatial == 2) + { + image_offsets = {y_img_offset, x_img_offset}; + effs = {y_eff, x_eff}; + } + else if constexpr(NDimSpatial == 3) + { + image_offsets = {z_img_offset, y_img_offset, x_img_offset}; + effs = {z_eff, y_eff, x_eff}; + } + + std::array independent_filters; + for(index_t i = 0; i < NDimSpatial; i++) + { + independent_filters[i] = + GetNumberOfIndependentFilters(input_spatial_lengths[i], + input_left_pads[i], + input_right_pads[i], + filter_spatial_lengths[i], + conv_filter_strides[i], + conv_filter_dilations[i], + image_offsets[i]); + } + const index_t independent_filters_acum = ck::accumulate_n( + independent_filters.begin(), NDimSpatial, 1, std::multiplies<>()); + if(independent_filters_acum <= 0) + continue; + + const auto in_grid_desc_m_k = + MakeInputDescriptor_M_K(N, + C, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + gemm_g_m_k_strides, + independent_filters, + effs); + const auto out_grid_desc_m_k = + MakeOutDescriptor_M_K(N, + C, + input_spatial_lengths, + filter_spatial_lengths, + image_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + image_offsets, + independent_filters, + effs); + in_grid_desc_m_k_container_.push_back(in_grid_desc_m_k); + out_grid_desc_m_k_container_.push_back(out_grid_desc_m_k); + + const index_t x_idx = x_img_offset / conv_filter_strides[XIdx]; + const index_t y_idx = y_img_offset / conv_filter_strides[YIdx]; + const index_t z_idx = z_img_offset / conv_filter_strides[ZIdx]; + + const index_t x_offset_with_pad = + math::max(0, x_img_offset - input_left_pads[XIdx]); + const index_t y_offset_with_pad = + math::max(0, y_img_offset - input_left_pads[YIdx]); + const index_t z_offset_with_pad = + math::max(0, z_img_offset - input_left_pads[ZIdx]); + + // Memory offsets to next set of independent filters, + // move to independent filters in each dimension + const index_t in_offset = + (x_idx + y_idx * output_spatial_lengths[XIdx] + + z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) * + gemm_g_m_k_strides[I1]; + // Move to independent filters in appropriate dimensions + const index_t out_offset = + x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] + + y_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + YIdx] + + z_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + ZIdx]; + + const InputDataType* p_in_with_offset = + static_cast(p_in) + in_offset; + OutputDataType* p_out_with_offset = + static_cast(p_out) + out_offset; + p_in_container_.push_back(p_in_with_offset); + p_out_container_.push_back(p_out_with_offset); + } + } + } + } + + void Print() const + { + for(std::size_t i = 0; i < in_grid_desc_m_k_container_.size(); i++) + { + std::cout << in_grid_desc_m_k_container_[i] << std::endl; + std::cout << out_grid_desc_m_k_container_[i] << std::endl; + } + } + + const 
ck::index_t G_; + const ck::index_t C_; + const ck::index_t X_; + + const InputDataType* p_in_; + OutputDataType* p_out_; + + const std::array& image_g_n_c_wis_strides_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + + std::vector in_grid_desc_m_k_container_; + std::vector out_grid_desc_m_k_container_; + + std::vector p_in_container_; + std::vector p_out_container_; + + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float elapsed_time = 0.f; + const auto kernel = kernel_tensor_rearrange, + GridwiseTensorRearrangeKernel>; + + // Execute each set of independent filters + for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) + { + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt( + arg.out_grid_desc_m_k_container_[i]); + const index_t grid_size = + block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_; + elapsed_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k_container_[i], + arg.p_in_container_[i], + arg.out_grid_desc_m_k_container_[i], + arg.p_out_container_[i], + arg.G_, + block_2_tile_map, + arg.compute_ptr_offset_of_batch_); + } + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const Argument& arg) + { + using namespace tensor_layout::convolution; + if constexpr(!(is_NSpatialGC || is_GNSpatialC)) + { + return false; + } + + const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1]; + const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; + const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; + const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; + bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; + bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1; + + // check vector acces with c not packed + if(!is_c_packed && ScalarPerVector != 1) + return false; + // check vector access of filter window row (only C if C is not packed) + if(!is_w_packed && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of filter window row (X * C) + if(arg.X_ * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of pads (w_pad_left/w_pad_right * C) + if(w_pad_left * arg.C_ % ScalarPerVector != 0 || + w_pad_right * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with stride and pad + if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with dilation + if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + + bool valid = true; + for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) + { + valid &= GridwiseTensorRearrangeKernel::CheckValidity( + arg.in_grid_desc_m_k_container_[i], arg.out_grid_desc_m_k_container_[i]); + } + return valid; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_in, // input image + 
void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + return Argument{static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) override + { + return std::make_unique(static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceColumnToImage" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << KPerBlock << ", " + << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dae16612ccaeb92273895d2de599261f76c43e9c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,848 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
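The DeviceColumnToImageImpl interface above follows the usual CK device-op flow (argument pointer, support check, invoker). A minimal usage sketch, with the template arguments elided; the alias, pointer names and the StreamConfig flags below are illustrative assumptions rather than part of this patch:

    using ColToImg = ck::tensor_operation::device::DeviceColumnToImageImpl<
        /* NDimSpatial, image layout, data types, BlockSize, MPerBlock, KPerBlock, ScalarPerVector */>;

    ColToImg op;
    auto arg     = op.MakeArgumentPointer(p_gemm_form, p_image, G, N, C,
                                          input_spatial_lengths, filter_spatial_lengths,
                                          output_spatial_lengths, image_g_n_c_wis_strides,
                                          gemm_g_m_k_strides, conv_filter_strides,
                                          conv_filter_dilations, input_left_pads, input_right_pads);
    auto invoker = op.MakeInvokerPointer();
    if(op.IsSupportedArgument(arg.get()))
    {
        invoker->Run(arg.get(), StreamConfig{nullptr, true /* time kernels */});
    }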
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = as_grid_desc_ak0_m_ak1; + ignore = bs_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceContractionMultipleABD_Xdl_CShuffle + : public DeviceContractionMultipleABD +{ + using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ComputeDataType = EDataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_, + const std::vector& a_ms_ks_strides_) + { + assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK && + a_ms_ks_strides_.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] 
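+    // (e.g. with NumDimM = NumDimK = 2 and lengths {M0, M1, K0, K1} = {16, 8, 32, 4},
+    //  the merged view is a 128 x 128 MRaw x KRaw matrix before padding; these sizes are
+    //  only an illustration)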
+ const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + __host__ __device__ static auto + MakeAsGridDescriptor_M_K(const std::array, NumATensor>& as_ms_ks_lengths, + const std::array, NumATensor>& as_ms_ks_strides) + { + return generate_tuple( + [&](auto i) { + return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]); + }, + Number{}); + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_, + const std::vector& b_ns_ks_strides_) + { + assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK && + b_ns_ks_strides_.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + __host__ __device__ static auto + MakeBsGridDescriptor_N_K(const std::array, NumBTensor>& bs_ns_ks_lengths, + const std::array, NumBTensor>& bs_ns_ks_strides) + { + return generate_tuple( + [&](auto i) { + return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]); + }, + Number{}); + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_, + const std::vector& e_ms_ns_strides_) + { + assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN && + e_ms_ns_strides_.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] 
+ const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto + MakeDsGridDescriptor_M_N(const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AsGridDesc_M_K = remove_cvref_t; + using BsGridDesc_N_K = remove_cvref_t; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t; + + // desc for blockwise copy + using AsGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BsGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::array p_as_grid, + std::array p_bs_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + as_grid_desc_m_k_{}, + bs_grid_desc_n_k_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)}, + as_grid_desc_ak0_m_ak1_{}, + bs_grid_desc_bk0_n_bk1_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + // using ALayout = remove_cvref_t>; + using ADataType = remove_cvref_t>; + + // A pointer + p_as_grid_(i) = static_cast(p_as_grid[i]); + + // A desc + as_grid_desc_m_k_(i) = + MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + // using BLayout = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + // B pointer + p_bs_grid_(i) = static_cast(p_bs_grid[i]); + + // B desc + bs_grid_desc_n_k_(i) = + MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]); + }); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + // using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + 
ds_grid_desc_m_n_(i) = + MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, + bs_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + as_grid_desc_ak0_m_ak1_ = + GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); + + bs_grid_desc_bk0_n_bk1_ = + GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + + // for sanity check of vector memory access + for(index_t i = 0; i < NumATensor; ++i) + { + tie(as_continous_dim_[i], as_max_read_elems_[i]) = + CalculateMaxRead(a_ms_ks_lengths[i], a_ms_ks_strides[i]); + } + + for(index_t i = 0; i < NumBTensor; ++i) + { + tie(bs_continous_dim_[i], bs_max_read_elems_[i]) = + CalculateMaxRead(b_ns_ks_lengths[i], b_ns_ks_strides[i]); + } + + for(index_t i = 0; i < NumDTensor; ++i) + { + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = + CalculateMaxRead(d_ms_ns_lengths[i], d_ms_ns_strides[i]); + } + + tie(e_continous_dim_, e_max_write_elems_) = + CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); + } + + // pointers + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AsGridDesc_M_K as_grid_desc_m_k_; + BsGridDesc_N_K bs_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; + BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Describe whether the last part of a given dimension of A/B/D/E is continues dim. + std::array as_continous_dim_; + std::array bs_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; + + std::array as_max_read_elems_; + std::array bs_max_read_elems_; + std::array ds_max_read_elems_; + index_t e_max_write_elems_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle< + GridwiseGemm, + typename GridwiseGemm::AsGridPointer, + typename GridwiseGemm::BsGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AsGridDesc_AK0_M_AK1, + DeviceOp::BsGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.as_grid_desc_ak0_m_ak1_, + arg.bs_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // check vector load/store + { + bool valid_as_access = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + const bool valid_a_vector_size = + arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0; + const bool valid_a_access_dim_m = + ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0; + const bool valid_a_access_dim_k = + ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1; + const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; + if(!((valid_a_vector_size && valid_a_access_dim) || + ABlockTransferSrcScalarPerVector == 1)) + { + valid_as_access = false; + } + }); + if(!valid_as_access) + { + return false; + } + + bool valid_bs_access = true; + static_for<0, NumBTensor, 1>{}([&](auto i) { + const bool valid_b_vector_size = + arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0; + const bool valid_b_access_dim_n = + BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0; + const bool valid_b_access_dim_k = + BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1; + const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; + if(!((valid_b_vector_size && valid_b_access_dim) || + BBlockTransferSrcScalarPerVector == 1)) + { + valid_bs_access = false; + } + }); + if(!valid_bs_access) + { + return false; + } + + bool valid_ds_access = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + const bool valid_d_vector_size = + arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector read of Ds is always on N dimension. 
+ const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1; + if(!((valid_d_vector_size && valid_d_access_dim) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) + { + valid_ds_access = false; + } + }); + if(!valid_ds_access) + { + return false; + } + + const bool valid_e_vector_size = + arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector write of E is always on N dimension. + const bool valid_e_access_dim = arg.e_continous_dim_ == 1; + if(!((valid_e_vector_size && valid_e_access_dim) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + d_ms_ns_lengths, + d_ms_ns_strides, + e_ms_ns_length, + e_ms_ns_stride, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& as_ms_ks_lengths, + const std::array, NumATensor>& as_ms_ks_strides, + const std::array, NumBTensor>& bs_ns_ks_lengths, + const std::array, NumBTensor>& bs_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + as_ms_ks_lengths, + as_ms_ks_strides, + bs_ns_ks_lengths, + bs_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_length, + e_ms_ns_stride, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceContractionMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << 
CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index deb938156d23d89297cb8e132dfcad3d5ad45c1c..f0f89f1d1b45b4128647f2ccee5e3a9580655ac4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -52,8 +53,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -145,7 +145,8 @@ template + typename ComputeDataType = ADataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceContractionMultipleD_Xdl_CShuffle : public DeviceContractionMultipleD + CDEElementwiseOperation, + ComputeDataType> { using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; @@ -181,7 +183,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return generate_tuple([&](auto i) { return vec[i]; }, num); }; - const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number{}); // dimension Ids for M0, M1, ... @@ -192,14 +194,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle typename arithmetic_sequence_gen::type{}; // lengths for M0, M1, ... - const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds); + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); // lengths for K0, K1, ... - const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds); + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] const auto a_grid_desc_ms_ks = - make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides); + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] 
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( @@ -313,6 +315,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -379,7 +383,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const void* p_b_grid, std::array p_ds_grid, void* p_e_grid, - const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_lengths, const std::vector& a_ms_ks_strides, const std::vector& b_ns_ks_lengths, const std::vector& b_ns_ks_strides, @@ -394,7 +398,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle p_b_grid_{static_cast(p_b_grid)}, p_ds_grid_{}, p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, @@ -407,13 +411,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, - cde_element_op_{cde_element_op}, - a_mz_stride_{}, - a_kz_stride_{}, - b_nz_stride_{}, - b_kz_stride_{}, - ds_nz_stride_{}, - e_nz_stride_{} + cde_element_op_{cde_element_op} { // populate pointer, batch stride, desc for Ds static_for<0, NumDTensor, 1>{}([&](auto i) { @@ -444,18 +442,20 @@ struct DeviceContractionMultipleD_Xdl_CShuffle } // for sanity check of vector memory access - a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; - a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; + tie(a_continous_dim_, a_max_read_elems_) = + CalculateMaxRead(a_ms_ks_lengths, a_ms_ks_strides); - b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; - b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1]; + tie(b_continous_dim_, b_max_read_elems_) = + CalculateMaxRead(b_ns_ks_lengths, b_ns_ks_strides); for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = + CalculateMaxRead(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); } - e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; + tie(e_continous_dim_, e_max_write_elems_) = + CalculateMaxRead(e_ms_ns_lengths, e_ms_ns_strides); } void Print() const @@ -495,15 +495,16 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Strides for the last M/N/K dimensions of A/B/Ds/E - // for sanity check of vector load/store - index_t a_mz_stride_; - index_t a_kz_stride_; - index_t b_nz_stride_; - index_t b_kz_stride_; - std::array ds_nz_stride_; - index_t e_mz_stride_; - index_t e_nz_stride_; + // Describe whether the last part of a given dimension of A/B/D/E is continues dim. 
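+        // (0 means the first index group is the contiguous one: M for A and Ds/E, N for B;
+        //  1 means the second group: K for A/B, N for Ds/E. IsSupportedArgument combines these
+        //  flags with the *_max_read_elems_ values below to validate vectorized access.)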
+ index_t a_continous_dim_; + index_t b_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; + + index_t a_max_read_elems_; + index_t b_max_read_elems_; + std::array ds_max_read_elems_; + index_t e_max_write_elems_; }; // Invoker @@ -591,7 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return false; } - if(ck::get_device_name() != "gfx90a" && std::is_same::value) + if(!ck::is_lds_direct_load_supported() && std::is_same::value) { return false; } @@ -610,65 +611,55 @@ struct DeviceContractionMultipleD_Xdl_CShuffle (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), "wrong!"); - // vector memory access of A: could be on M or AK1 dimension - if constexpr(ABlockTransferSrcVectorDim == 1) - { - if(!(arg.a_mz_stride_ == 1 && - arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else + const bool valid_a_vector_size = + arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; + const bool valid_a_access_dim_m = + ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0; + const bool valid_a_access_dim_k = + ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1; + const bool valid_a_access_dim = + valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1; + if(!(valid_a_vector_size && valid_a_access_dim)) { - if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) - { - return false; - } + return false; } - // vector memory access of B: could be on N or BK1 dimension - if constexpr(BBlockTransferSrcVectorDim == 1) + const bool valid_b_vector_size = + arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; + const bool valid_b_access_dim_n = + BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0; + const bool valid_b_access_dim_k = + BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1; + const bool valid_b_access_dim = + valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1; + if(!(valid_b_vector_size && valid_b_access_dim)) { - if(!(arg.b_nz_stride_ == 1 && - arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else - { - if(!(arg.b_kz_stride_ == 1 && - arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) - { - return false; - } + return false; } - // vector memory access of Ds: always on NPerBlock dimension - bool valid_d_access = true; - + bool valid_ds_access = true; static_for<0, NumDTensor, 1>{}([&](auto i) { - if(!(arg.ds_nz_stride_[i] == 1 && - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) + const bool valid_d_vector_size = + arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector read of Ds is always on N dimension. 
+ const bool valid_d_access_dim = + arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; + if(!(valid_d_vector_size && valid_d_access_dim)) { - valid_d_access = false; + valid_ds_access = false; } }); - - if(valid_d_access == false) + if(!valid_ds_access) { return false; } - // vector memory access of E: always on NPerBlock dimension - if(!(arg.e_nz_stride_ == 1 && - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) + const bool valid_e_vector_size = + arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector write of E is always on N dimension. + const bool valid_e_access_dim = + arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; + if(!(valid_e_vector_size && valid_e_access_dim)) { return false; } @@ -686,7 +677,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const void* p_b, std::array p_ds, void* p_e, - const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_lengths, const std::vector& a_ms_ks_strides, const std::vector& b_ns_ks_lengths, const std::vector& b_ns_ks_strides, @@ -702,7 +693,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle p_b, p_ds, p_e, - a_ms_ns_lengths, + a_ms_ks_lengths, a_ms_ks_strides, b_ns_ks_lengths, b_ns_ks_strides, @@ -723,7 +714,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const void* p_b, std::array p_ds, void* p_e, - const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_lengths, const std::vector& a_ms_ks_strides, const std::vector& b_ns_ks_lengths, const std::vector& b_ns_ks_strides, @@ -739,7 +730,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle p_b, p_ds, p_e, - a_ms_ns_lengths, + a_ms_ks_lengths, a_ms_ks_strides, b_ns_ks_lengths, b_ns_ks_strides, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1b0db73fdd097210c4a31a9c456dd675b300ceb8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * Calculates the maximum number of subsequent elements of the fast changing dimension + * that are consecutive in memory. + * + * Example: + * NumDimM = 2, NumDimK = 3 + * A shape = [ 2, 3, 4, 5, 6] + * A strides = [360, 120, 30, 6, 1] + * | M | | K | + * It follows from strides that K is FCD and all the subsequent elements of K are consecutive + * in memory. + * But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be + * consecutive in memory. + * + * Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions. 
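+ *
+ * For the fully packed strides above ([360, 120, 30, 6, 1]) the function returns
+ * {continuous_dim = 1 (the K group), max_subsequent_elems = 4 * 5 * 6 = 120};
+ * for strides [360, 120, 6, 24, 1] it returns {1, 6}.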
+ */ +template +auto CalculateMaxRead(const std::vector& lengths, const std::vector& strides) +{ + if(lengths.size() != NumDim1 + NumDim2) + { + std::ostringstream err; + err << "Incorrect number of lengths in " + << "device_contraction_utils.hpp" + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + if(strides.size() != NumDim1 + NumDim2) + { + std::ostringstream err; + err << "Incorrect number of strides in " + << "device_contraction_utils.hpp" + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + // Determine the beginning and end idx of the group representing the FCD. + index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1; + if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1) + { + // MZ or KZ are ones + bool dims1_are_ones = true; + for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++) + { + if(lengths[dim_idx] != 1) + { + dims1_are_ones = false; + } + } + + if(dims1_are_ones) + { + begin_idx = NumDim1; + end_idx = NumDim1 + NumDim2 - 1; + continous_dim = 1; + } + else + { + begin_idx = 0; + end_idx = NumDim1 - 1; + continous_dim = 0; + } + } + else if(strides[NumDim1 - 1] == 1) + { + begin_idx = 0; + end_idx = NumDim1 - 1; + continous_dim = 0; + } + else if(strides[NumDim1 + NumDim2 - 1] == 1) + { + begin_idx = NumDim1; + end_idx = NumDim1 + NumDim2 - 1; + continous_dim = 1; + } + else + { + // The dimension consecutive in memory is not the last dimension of any group, so only + // one element can be read/written at once. + consecutive_stride = 1; + continous_dim = 0; + return make_tuple(continous_dim, consecutive_stride); + } + + for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx) + { + if(strides[dim_idx] == consecutive_stride) + { + consecutive_stride *= lengths[dim_idx]; + } + else + { + break; + } + } + const index_t max_subsequent_elems = consecutive_stride; + return make_tuple(continous_dim, max_subsequent_elems); +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index e22c5a2aa514d4ace3b05074be06a539a7b54433..0b73317c5e6d38aabaeb6c0d7fb79c74c2a2e607 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -516,26 +516,27 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "arg.a_grid_desc_k0_m_k1_container_{" - << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " - << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" - << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_container_{" - << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " - << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" - << std::endl; - - std::cout << "arg.c_grid_desc_m_n_container_{ " - << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " - << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" - << std::endl; + { + std::cout << 
"arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + } } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index c9e8940edc4030be672d7c9def1225c9f650eee5..13eb23574f76118bc7021391d2746b5ae9737493 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -644,7 +644,7 @@ struct float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " @@ -664,9 +664,7 @@ struct << arg.input_left_pads_[1] << ", " << std::endl; std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " << arg.input_right_pads_[1] << ", " << std::endl; - } - { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; @@ -684,7 +682,6 @@ struct std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 28fceb428eff34628cc2f5c680bd9a45c0b856e6..28778d825ba7bd2fe8b73da67fc06bdc47861536 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " @@ -634,9 +634,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X << arg.input_left_pads_[1] << ", " << std::endl; std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " << arg.input_right_pads_[1] << ", " << std::endl; - } - { std::cout << 
"arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; @@ -651,7 +649,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index ca291d3b11f5273caaf7900bf349eddceb108658..7fa231d4f438103d582844b4264a8091e0124847 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " @@ -599,9 +599,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W << arg.input_left_pads_[1] << ", " << std::endl; std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " << arg.input_right_pads_[1] << ", " << std::endl; - } - { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; @@ -635,7 +633,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W .GetLength(I5) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index ef94120f4e89f199018fc62c5af481fc0aca4f83..3be7313d2bb2ae42f500c09c570ddd9cd512342a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " @@ -444,7 +444,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } -#endif + if(!GridwiseGemm::CheckValidity( arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 
a8e586b20c38a40438636c5f9cafafb5c34409a9..6e6921351356740f899daab6b6343a46008072fc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -56,7 +56,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) @@ -415,7 +415,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 3178f73f4b530e67b399cd0ae71064a4b66fa837..1edae33be3ca2fc105fa142e323e00c3f3f5814e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1_container_{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " @@ -1305,7 +1305,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5) << " ) " << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], @@ -1393,9 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl static bool IsSupportedArgument(const Argument& arg) { // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102")) + if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || + ck::is_gfx11_supported() || ck::is_gfx12_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index ee3f0cea1b37f1de5ce3f53d6252e62024c8ddf8..de8f35a64070bc7e29b538288391533c3940646c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl float ave_time 
= 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " @@ -1239,7 +1239,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp deleted file mode 100644 index 02ef29e32ddc2bbd985fb4f46c6eb3825c00c11a..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp +++ /dev/null @@ -1,338 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/math.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" - -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/stream_utility.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -template -struct DeviceElementwise2dImpl : public DeviceElementwise -{ - static constexpr index_t NumDim = NumDim_m + NumDim_n; - - static constexpr int NumInput = InDataTypeTuple::Size(); - static constexpr int NumOutput = OutDataTypeTuple::Size(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static_assert(NumInput == InScalarPerVectorSeq::Size() && - NumOutput == OutScalarPerVectorSeq::Size(), - "Tuple size is inconsistent with the number of in/out!"); - - static auto GenerateInDataTypePointerTuple() - { - return generate_tuple( - [&](auto I) { - using DataType = remove_cvref_t; - - return static_cast(nullptr); - }, - Number{}); - }; - - static auto GenerateOutDataTypePointerTuple() - { - return generate_tuple( - [&](auto I) { - using DataType = remove_cvref_t; - - return static_cast(nullptr); - }, - Number{}); - }; - - using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); - using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); - - template - static auto PadDescriptor_MN_2d(Desc_MN desc_mn, - index_t gridSize, - index_t blockSize, - index_t num_threads_m, - index_t num_threads_n) - { - std::ignore = blockSize; - std::ignore = gridSize; - const auto m = desc_mn.GetLength(I0); - const auto n = desc_mn.GetLength(I1); - const index_t loop_step_m = num_threads_m * MPerThread; - const index_t loop_step_n = num_threads_n * NPerThread; - const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m; - const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n; - - const auto desc_mn_pad = transform_tensor_descriptor( - desc_mn, - make_tuple(make_right_pad_transform(m, pad_m), make_right_pad_transform(n, pad_n)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - return desc_mn_pad; - } - - static auto MakeDescriptor_MN(const std::array& lengths, - const std::array& stride, - index_t gridSize, - index_t blockSize, - index_t num_threads_m, - index_t num_threads_n) - { - auto tupleOfShape = generate_tuple([&](auto I) 
{ return lengths[I]; }, Number{}); - auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); - - // nd desc - [s0, s1, s2, ...] - const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); - - constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type(); - constexpr auto nDimIds = - typename arithmetic_sequence_gen::type(); - - const auto mLengths = get_container_subset(tupleOfShape, mDimIds); - const auto nLengths = get_container_subset(tupleOfShape, nDimIds); - - // merge nd to 2d desc - [s0 * s1 * ...] - - if constexpr(NumDim > 2) - { - const auto desc_mn = transform_tensor_descriptor( - desc, - make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), - make_tuple(mDimIds, nDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return PadDescriptor_MN_2d(desc_mn, gridSize, blockSize, num_threads_m, num_threads_n); - } - else - return PadDescriptor_MN_2d(desc, gridSize, blockSize, num_threads_m, num_threads_n); - } - - template - static auto GenerateInOutGrid2dDescTuple(Number) - { - return generate_tuple( - [&](auto) { - if constexpr(NumDim > 2) - { - return MakeDescriptor_MN({1, 1}, {1, 1}, 1, 1, 1, 1); - } - else - { - return MakeDescriptor_MN({1}, {1}, 1, 1, 1, 1); - }; - }, - Number{}); - }; - - using OutGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number{})); - using InGrid2dDescTuple = decltype(GenerateInOutGrid2dDescTuple(Number{})); - - using GridwiseElementwise = GridwiseElementwise_2D; - - struct Argument : public BaseArgument - { - Argument(const std::array lengths, - const std::array, NumInput> inStridesArray, - const std::array, NumOutput> outStridesArray, - const std::array in_dev_buffers, - const std::array out_dev_buffers, - ElementwiseOperation elementwise_op) - - : lengths_(lengths), - inStridesArray_(inStridesArray), - outStridesArray_(outStridesArray), - elementwise_op_(elementwise_op), - blockSize_(256) - { - static_assert(NumDim_m > 0, ""); - static_assert(NumDim_n > 0, ""); - - in_dev_buffers_ = generate_tuple( - [&](auto I) { - using DataType = remove_cvref_t; - return static_cast(in_dev_buffers[I.value]); - }, - Number{}); - - out_dev_buffers_ = generate_tuple( - [&](auto I) { - using DataType = remove_cvref_t; - return static_cast(out_dev_buffers[I.value]); - }, - Number{}); - } - - InDataTypePointerTuple in_dev_buffers_; - OutDataTypePointerTuple out_dev_buffers_; - - std::array lengths_; - std::array, NumInput> inStridesArray_; - std::array, NumOutput> outStridesArray_; - - ElementwiseOperation elementwise_op_; - index_t blockSize_; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - index_t gridSize = getAvailableComputeUnitCount(stream_config); - index_t num_threads_m = (gridSize * arg.blockSize_) / 16; - index_t num_threads_n = 16; - - auto in_grid_2d_desc_tuple = generate_tuple( - [&](auto I) { - return MakeDescriptor_MN(arg.lengths_, - arg.inStridesArray_[I.value], - gridSize, - arg.blockSize_, - num_threads_m, - num_threads_n); - }, - Number{}); - - auto out_grid_2d_desc_tuple = generate_tuple( - [&](auto I) { - return MakeDescriptor_MN(arg.lengths_, - arg.outStridesArray_[I.value], - gridSize, - arg.blockSize_, - num_threads_m, - num_threads_n); - }, - Number{}); - - const auto kernel = kernel_elementwise_2d; - - float elapsed_time = launch_and_time_kernel(stream_config, - kernel, - dim3(gridSize), - dim3(arg.blockSize_), - 0, - in_grid_2d_desc_tuple, - 
out_grid_2d_desc_tuple, - arg.in_dev_buffers_, - arg.out_dev_buffers_, - arg.elementwise_op_, - num_threads_m, - num_threads_n); - return elapsed_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* pArg = dynamic_cast(p_arg); - - if(pArg == nullptr) - return false; - - if(pArg->lengths_.back() % MPerThread != 0) - return false; - - auto IsScalarPerVectorValid = [&](const std::array& lengths, - const std::array& strides, - index_t scalarPerVector, - index_t vectorDim) { - if(strides[vectorDim] == 1 && - (lengths[vectorDim] % scalarPerVector == 0 || - lengths[vectorDim] % scalarPerVector == lengths[vectorDim])) - { - return true; - } - if(strides[vectorDim] != 1 && scalarPerVector == strides[vectorDim]) - { - return true; - } - return false; - }; - - bool valid = true; - static_for<0, NumInput, 1>{}([&](auto I) { - if(!IsScalarPerVectorValid(pArg->lengths_, - pArg->inStridesArray_[I.value], - InScalarPerVectorSeq::At(I), - NumDim_m - 1)) - valid = false; - }); - - static_for<0, NumOutput, 1>{}([&](auto I) { - if(!IsScalarPerVectorValid(pArg->lengths_, - pArg->outStridesArray_[I.value], - OutScalarPerVectorSeq::At(I), - NumDim - 1)) - valid = false; - }); - - return valid; - }; - - std::unique_ptr - MakeArgumentPointer(const std::array lengths, - const std::array, NumInput> inStridesArray, - const std::array, NumOutput> outStridesArray, - const std::array in_dev_buffers, - const std::array out_dev_buffers, - ElementwiseOperation elementwise_op) override - { - return std::make_unique(lengths, - inStridesArray, - outStridesArray, - in_dev_buffers, - out_dev_buffers, - elementwise_op); - } - - static auto MakeInvoker() { return Invoker{}; } - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(); - }; -}; // namespace device - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bdc6dc998180eeb6790daa7151c25203b8f63d41 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp @@ -0,0 +1,424 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
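+//
+// Summary of the implementation below: the vectorized (fastest-varying)
+// dimensions are chosen at runtime from the tensor strides. M1 follows the
+// lowest-stride dimension of the inputs, M0 the lowest-stride dimension of the
+// outputs, and the remaining dimensions are merged into M0 (see MakeDescriptor).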
+ +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwiseImpl + : public DeviceElementwise +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + static index_t GetLowestStrideDim(const std::array& strides) + { + index_t most_continous_dim = NumDim - 1; + index_t most_continous_dim_stride = strides[most_continous_dim]; + for(index_t dim = 0; dim < NumDim; dim++) + { + if(strides[dim] < most_continous_dim_stride) + { + most_continous_dim_stride = strides[dim]; + most_continous_dim = dim; + } + } + return most_continous_dim; + } + + template + static auto PadInputOutputDescriptor(const InOutDescriptor& desc) + { + const auto M0 = desc.GetLength(I0); + const auto M1 = desc.GetLength(I1); + const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0; + const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1; + + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return padded_desc; + } + + static auto GenerateBatchDimsLenghtsTuple(const std::array& lengths, + const index_t M0_dim, + const index_t M1_dim) + { + // Generate batch dims, they will be merged to M0 + // Add one more dim than needed in case that M0 is equal to M1 + // If M0 is equal to M1, then will be one more batch dim + std::array batch_dims; + index_t batch_dim = 0; + for(index_t i = 0; i < NumDim; i++) + { + if(i != M0_dim && i != M1_dim) + { + batch_dims[batch_dim] = lengths[i]; + batch_dim++; + } + } + // Add dummy dim if M0_dim is not equal to M1_dim + if(M0_dim != M1_dim && NumDim >= 2) + batch_dims[NumDim - 2] = 1; + return generate_tuple([&](auto I) { return batch_dims[I]; }, Number{}); + } + + static auto MakeDescriptor(const std::array& lengths, + const std::array& in_strides, + const std::array& out_strides, + const std::array& desc_strides) + { + const auto M0_dim = GetLowestStrideDim(out_strides); + const auto M1_dim = GetLowestStrideDim(in_strides); + + // If M0_dim is equal to M1_dim, then make M0_dim dummy + const auto M0 = M0_dim == 
M1_dim ? I1 : lengths[M0_dim]; + const auto M1 = lengths[M1_dim]; + const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim]; + const auto M1_stride = desc_strides[M1_dim]; + + const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim); + const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim); + + const auto desc = make_naive_tensor_descriptor( + concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)), + concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride))); + // Merged batch dims with M0 + const auto transforms = + make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))), + make_pass_through_transform(M1)); + using BatchElemsSequence = + typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type; + const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence{}); + const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{}); + // desc: (merged_dims + M0, M1) + auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); + return PadInputOutputDescriptor(merged_desc); + } + + template + static auto GenerateInOutGridDescTuple() + { + std::array ones; + for(index_t d = 0; d < NumDim; d++) + { + ones[d] = 1; + } + + return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); }, + Number{}); + }; + + using InGridDescTuple = decltype(GenerateInOutGridDescTuple()); + using OutGridDescTuple = decltype(GenerateInOutGridDescTuple()); + + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwiseOp = GridwiseElementwise; + + using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op) + { + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto in_grid_desc_tuple = generate_tuple( + [&](auto src_i) { + // Use Strides from first tensor to assert that M0 dim and + // M1 dim are the same for each tensor. 
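                    // All descriptors therefore tile over the same M0/M1 dims:
                    // M0 follows outStridesArray_[I0], M1 follows
                    // inStridesArray_[I0]; only desc_strides differs per tensor.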
+ return MakeDescriptor(arg.lengths_, + arg.inStridesArray_[I0], + arg.outStridesArray_[I0], + arg.inStridesArray_[src_i]); + }, + Number{}); + + auto out_grid_desc_tuple = generate_tuple( + [&](auto dst_i) { + return MakeDescriptor(arg.lengths_, + arg.inStridesArray_[I0], + arg.outStridesArray_[I0], + arg.outStridesArray_[dst_i]); + }, + Number{}); + + const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number{}); + const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number{}); + + const auto block_2_tile_map = Block2TileMap(M0, M1); + const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1); + + const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) == + GetLowestStrideDim(arg.outStridesArray_[I0]); + + const auto kernel = in_out_same_vector_dim + ? kernel_elementwise + : kernel_elementwise; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + in_grid_desc_tuple, + out_grid_desc_tuple, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + block_2_tile_map, + arg.elementwise_op_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]); + const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]); + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector, + index_t M_dim) { + if(scalarPerVector == 1) + { + return true; + } + if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0) + { + return true; + } + return false; + }; + + bool is_valid = true; + static_for<0, NumInput, 1>{}([&](auto I) { + static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 && + M1PerThread % InScalarPerVectorSeq::At(I) == 0); + is_valid &= IsScalarPerVectorValid( + arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 && + M1PerThread % OutScalarPerVectorSeq::At(I) == 0); + is_valid &= IsScalarPerVectorValid( + arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim); + }); + + return is_valid; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + { + return Argument{lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() 
const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseImpl<"; + str << NumDim << ", "; + str << BlockSize << ", "; + str << M0PerBlock << ", "; + str << M1PerBlock << ", "; + str << M0PerThread << ", "; + str << M1PerThread << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp deleted file mode 100644 index e203468c6174f0ae2c8bde33a716d0665f2e8b6c..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp +++ /dev/null @@ -1,303 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/math.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" - -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/stream_utility.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DeviceElementwiseImpl - : public DeviceElementwise -{ - static constexpr int NumInput = InDataTypeTuple::Size(); - static constexpr int NumOutput = OutDataTypeTuple::Size(); - - static_assert(NumInput == InScalarPerVectorSeq::Size() && - NumOutput == OutScalarPerVectorSeq::Size(), - "Tuple size is inconsistent with the number of in/out!"); - - static auto GenerateInDataTypePointerTuple() - { - return generate_tuple( - [&](auto I) { - using DataType = remove_cvref_t; - - return static_cast(nullptr); - }, - Number{}); - }; - - static auto GenerateOutDataTypePointerTuple() - { - return generate_tuple( - [&](auto I) { - using DataType = remove_cvref_t; - - return static_cast(nullptr); - }, - Number{}); - }; - - using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); - using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); - - template - static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) - { - constexpr auto I0 = Number<0>{}; - - const auto m = desc_m.GetLength(I0); - const index_t loop_step = gridSize * blockSize * MPerThread; - const auto pad = math::integer_least_multiple(m, loop_step) - m; - const auto desc_m_pad = - transform_tensor_descriptor(desc_m, - make_tuple(make_right_pad_transform(m, pad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - return desc_m_pad; - } - - static auto MakeDescriptor_M(const std::array& lengths, - const std::array& stride, - index_t gridSize, - index_t blockSize) - { - auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); - auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); - - // nd desc - [s0, s1, s2, ...] - const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); - - // merge nd to 1d desc - [s0 * s1 * ...] 
- if constexpr(NumDim > 1) - { - const auto desc_m = transform_tensor_descriptor( - desc, - make_tuple(make_merge_transform(tupleOfShape)), - make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), - make_tuple(Sequence<0>{})); - - return PadDescriptor_M_1d(desc_m, gridSize, blockSize); - } - else - return PadDescriptor_M_1d(desc, gridSize, blockSize); - } - - template - static auto GenerateInOutGrid1dDescTuple(Number) - { - return generate_tuple( - [&](auto) { - if constexpr(NumDim > 1) - { - return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1); - } - else - { - return MakeDescriptor_M({1}, {1}, 1, 1); - }; - }, - Number{}); - }; - - using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); - using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); - - using GridwiseElementwise = GridwiseElementwise_1D; - - struct Argument : public BaseArgument - { - Argument(const std::array lengths, - const std::array, NumInput> inStridesArray, - const std::array, NumOutput> outStridesArray, - const std::array in_dev_buffers, - const std::array out_dev_buffers, - ElementwiseOperation elementwise_op) - - : lengths_(lengths), - inStridesArray_(inStridesArray), - outStridesArray_(outStridesArray), - elementwise_op_(elementwise_op), - blockSize_(256) - { - in_dev_buffers_ = generate_tuple( - [&](auto I) { - using DataType = remove_cvref_t; - return static_cast(in_dev_buffers[I.value]); - }, - Number{}); - - out_dev_buffers_ = generate_tuple( - [&](auto I) { - using DataType = remove_cvref_t; - return static_cast(out_dev_buffers[I.value]); - }, - Number{}); - } - - InDataTypePointerTuple in_dev_buffers_; - OutDataTypePointerTuple out_dev_buffers_; - - std::array lengths_; - std::array, NumInput> inStridesArray_; - std::array, NumOutput> outStridesArray_; - - ElementwiseOperation elementwise_op_; - index_t blockSize_; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - index_t gridSize = getAvailableComputeUnitCount(stream_config); - - auto in_grid_1d_desc_tuple = generate_tuple( - [&](auto I) { - return MakeDescriptor_M( - arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_); - }, - Number{}); - - auto out_grid_1d_desc_tuple = generate_tuple( - [&](auto I) { - return MakeDescriptor_M( - arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_); - }, - Number{}); - - const auto kernel = kernel_elementwise_1d; - - float elapsed_time = launch_and_time_kernel(stream_config, - kernel, - dim3(gridSize), - dim3(arg.blockSize_), - 0, - in_grid_1d_desc_tuple, - out_grid_1d_desc_tuple, - arg.in_dev_buffers_, - arg.out_dev_buffers_, - arg.elementwise_op_); - return elapsed_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static bool IsSupportedArgument(const Argument& arg) - { - if(arg.lengths_.back() % MPerThread != 0) - return false; - - auto IsScalarPerVectorValid = [&](const std::array& lengths, - const std::array& strides, - index_t scalarPerVector) { - if(strides.back() == 1 && lengths.back() % scalarPerVector == 0) - return true; - - if(strides.back() != 1 && scalarPerVector == 1) - return true; - - return false; - }; - - bool valid = true; - static_for<0, NumInput, 1>{}([&](auto I) { - if(!IsScalarPerVectorValid( - arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) - valid = 
false; - }); - - static_for<0, NumOutput, 1>{}([&](auto I) { - if(!IsScalarPerVectorValid( - arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) - valid = false; - }); - - return valid; - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto - MakeArgument(const std::array lengths, - const std::array, NumInput> inStridesArray, - const std::array, NumOutput> outStridesArray, - const std::array in_dev_buffers, - const std::array out_dev_buffers, - ElementwiseOperation elementwise_op) - { - return Argument{lengths, - inStridesArray, - outStridesArray, - in_dev_buffers, - out_dev_buffers, - elementwise_op}; - } - - std::unique_ptr - MakeArgumentPointer(const std::array lengths, - const std::array, NumInput> inStridesArray, - const std::array, NumOutput> outStridesArray, - const std::array in_dev_buffers, - const std::array out_dev_buffers, - ElementwiseOperation elementwise_op) override - { - return std::make_unique(lengths, - inStridesArray, - outStridesArray, - in_dev_buffers, - out_dev_buffers, - elementwise_op); - } - - static auto MakeInvoker() { return Invoker{}; } - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(); - }; -}; // namespace device - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dff0530ee026abcde87d4bdc2a71d06ff61110a2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \note This structure is deprecated (left for backwards compatibility). Please use + * DeviceElementwiseImpl from device_elementwise_dynamic_vector_dims_impl.hpp. 
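+ *
+ * Unlike this implementation, which requires the innermost dimension to be
+ * contiguous (or a scalar-per-vector of 1), the replacement determines the
+ * vectorized dimensions at runtime from the input/output strides.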
+ */ +template +struct DeviceElementwiseImpl : public DeviceElementwise +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + constexpr auto I0 = Number<0>{}; + + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::array& lengths, + const std::array& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] 
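        // (Illustrative example: lengths {2, 3, 4} merge into a single dimension
        // of length 24, which PadDescriptor_M_1d then right-pads up to the next
        // multiple of gridSize * blockSize * MPerThread.)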
+ if constexpr(NumDim > 1) + { + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + else + return PadDescriptor_M_1d(desc, gridSize, blockSize); + } + + template + static auto GenerateInOutGrid1dDescTuple(Number) + { + return generate_tuple( + [&](auto) { + if constexpr(NumDim > 1) + { + return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1); + } + else + { + return MakeDescriptor_M({1}, {1}, 1, 1); + }; + }, + Number{}); + }; + + using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); + using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); + + using GridwiseElementwise = GridwiseElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op), + unary_op_(unary_op), + scale_op_(scale_op), + blockSize_(256) + { + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + UnaryOperation unary_op_; + Scale scale_op_; + index_t blockSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + + auto in_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + + auto out_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + + const auto kernel = kernel_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + in_grid_1d_desc_tuple, + out_grid_1d_desc_tuple, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + arg.elementwise_op_, + arg.unary_op_, + arg.scale_op_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector) { + if(strides.back() == 1 && lengths.back() % scalarPerVector == 0) + return true; + + if(strides.back() != 1 && scalarPerVector == 1) + return true; + + return false; + }; + + bool valid = 
true; + static_for<0, NumInput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid( + arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) + valid = false; + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid( + arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) + valid = false; + }); + + return valid; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) + { + return Argument{lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op, + unary_op, + scale_op}; + } + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op, + unary_op, + scale_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseNormalizationImpl<"; + str << NumDim << ", "; + str << MPerThread << ">"; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..553143e2861e93d1784a1a99281ef8711b94e2ba --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -0,0 +1,714 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N) +// 2. 
C(M, N) = A(M, K) * DequantB(K, N) + +template +struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? false : true; + + // If true, LDS is used unconditionally + // LDS bypass feature not implemented for dequantization pipeline. + static constexpr auto AEnableLds_manu = true; + static constexpr auto BEnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle; + + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else if constexpr(is_same_v) + { + const auto b_grid_desc_nraw_kraw = + 
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0) + { + // assume Scale is [1, N] + const auto scale_grid_desc_n_k = [&]() { + const auto scale_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw); + }(); + + const auto N = scale_grid_desc_n_k.GetLength(I0); + const auto K = scale_grid_desc_n_k.GetLength(I1); + // When K = 1, it might be scale tensor. + assert(K % K1 == 0 && K != 1); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + scale_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1 + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + scale_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + // Gridwise descriptor, mapping to whole given provblem. 
+ using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseFpAintBGemm_Wmma< + BlockSize, + ADataType, + BDataType, + ScaleDataType, + AccDataType, + CShuffleDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc, + BGridDesc, + ScaleGridDesc, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + K1, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const ScaleDataType* p_scale_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_scale_grid_{p_scale_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_{}, + b_grid_desc_{}, + scale_grid_desc_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + MRaw_{M}, + NRaw_{N}, + KRaw_{K} + { + a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA); + b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB); + scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0); + c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_, b_grid_desc_, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const ScaleDataType* p_scale_grid_; + CDataType* p_c_grid_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; + ScaleGridDesc scale_grid_desc_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation 
b_element_op_; + CElementwiseOperation c_element_op_; + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_fpAintB_gemm_wmma< + GridwiseGemm, + ADataType, + BDataType, + ScaleDataType, + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_scale_grid_, + arg.p_c_grid_, + arg.a_grid_desc_, + arg.b_grid_desc_, + arg.scale_grid_desc_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + printf("DeviceOp err: AccDataType"); + return false; + } + } + else + { + printf("DeviceOp err: Arch"); + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector store of C + // only 
support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const ScaleDataType* p_scale, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_scale, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_scale, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_scale), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}, + {PipelineVersion::weight_only, "weight_only"}}; + + // clang-format off + str << "DeviceFpAintBGemm_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index eedf384cd9c99fce8c32ecc54be476a036165f86..eb0fb55f5dbc8e547a7b3e0d591c423fb4f85227 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -11,7 +11,6 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/host_utility/device_prop.hpp" @@ -60,7 +59,6 
@@ template < typename CThreadTransferSrcDstAccessOrder, index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferDstScalarPerVector, - GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default, enable_if_t< is_same_v && is_same_v && @@ -238,8 +236,7 @@ struct DeviceGemmDl : public DeviceGemm; + CThreadTransferDstScalarPerVector>; using AGridDesc_K0_M0_M1_K1 = decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); @@ -276,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - true, - GemmDlAlg>; + true>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -402,8 +404,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - false, - GemmDlAlg>; + false>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -429,8 +430,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - true, - GemmDlAlg>; + true>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -456,8 +456,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - false, - GemmDlAlg>; + false>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -492,19 +491,52 @@ struct DeviceGemmDl : public DeviceGemm::value) + { + constexpr auto A_K_vec_length = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3); + if(arg.K_raw_ % A_K_vec_length != 0) + { + return false; + } + } + else + { + constexpr auto A_M_vec_lenght = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2); + if(arg.M_raw_ % A_M_vec_lenght != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + constexpr auto B_N_vec_lenght = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2); + if(arg.N_raw_ % B_N_vec_lenght != 0) + { + return false; + } + } + else { - if(ck::get_device_name() == "gfx1030") + constexpr auto B_K_vec_length = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3); + if(arg.K_raw_ % B_K_vec_length != 0) { - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + return false; } - return false; } - if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102") + if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || + ck::is_gfx11_supported() || ck::is_gfx12_supported()) { return GridwiseGemm::CheckValidity( arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp deleted file mode 100644 index 336ce31e8214391183a94f5af6951852bc5ed067..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp +++ /dev/null @@ -1,133 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
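The IsSupportedArgument additions above replace the old device-name whitelist with an explicit vector-access check: the source vector tensor for A spans (K0, M0, M1, K1), so the effective vector length along K is the product of its K0 and K1 components, and the raw K extent must divide evenly by it (the M and N cases are analogous). A small restatement of that rule, with assumed parameter names:

// Effective vector length along K for a (K0, M0, M1, K1) source vector tensor is
// the product of its K0 and K1 components; the raw K extent must be a multiple of
// it for the vectorized global load to stay in bounds. Illustrative helpers only.
constexpr int k_vector_length(int len_k0, int len_k1) { return len_k0 * len_k1; }

constexpr bool k_extent_is_vectorizable(int k_raw, int len_k0, int len_k1)
{
    return k_vector_length(len_k0, len_k1) > 0 &&
           k_raw % k_vector_length(len_k0, len_k1) == 0;
}

static_assert(k_extent_is_vectorizable(4096, 2, 4), "4096 is divisible by 2 * 4");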
- -#pragma once - -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - GemmSpecialization GemmSpec, - index_t BlockSize, - index_t MPerBlock, - index_t NPerBlock, - index_t K0PerBlock, - index_t K1, - index_t M1PerThread, - index_t N1PerThread, - index_t KPerThread, - typename M1N1ThreadClusterM1Xs, - typename M1N1ThreadClusterN1Xs, - typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, - typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, - typename ABlockTransferSrcVectorTensorContiguousDimOrder, - typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, - typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, - typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, - typename BBlockTransferSrcVectorTensorContiguousDimOrder, - typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, - typename CThreadTransferSrcDstAccessOrder, - index_t CThreadTransferSrcDstVectorDim, - index_t CThreadTransferDstScalarPerVector, - enable_if_t< - is_same_v && - is_same_v && - is_same_v, - bool> = false> -struct DeviceGemmDlDpp8 : public DeviceGemmDl - -{ - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmDlDpp8" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << K0PerBlock << ", " - << K1 << ", " - << M1PerThread << ", " - << N1PerThread << ", " - << KPerThread - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..168b9ad812106810cf676c2f44ec36686ada7658 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
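The new DeviceGemmDpp below, like the other device ops touched by this patch, is driven from the host through the same Argument / Invoker protocol (MakeArgument, IsSupportedArgument, MakeInvoker, Run). A hedged sketch of that calling sequence, written as a generic helper; the op type, problem parameters, and the StreamConfig timing flag are assumptions based on the interfaces visible in this patch, not a prescribed API:

#include <stdexcept>
#include <utility>

// Generic host-side driver: build an Argument, verify the problem is supported,
// then launch (and time) the kernel through an Invoker. DeviceOp stands for any
// concrete instantiation of the device ops introduced or modified in this patch.
template <typename DeviceOp, typename... ProblemArgs>
float run_device_op(DeviceOp& op, ProblemArgs&&... problem_args)
{
    auto argument = op.MakeArgument(std::forward<ProblemArgs>(problem_args)...);

    if(!op.IsSupportedArgument(argument))
    {
        throw std::runtime_error("device op does not support this problem");
    }

    auto invoker = op.MakeInvoker();

    // StreamConfig{nullptr, true} requests kernel timing on the default stream
    // (assumption based on how StreamConfig is used elsewhere in this code base).
    return invoker.Run(argument, StreamConfig{nullptr, true});
}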
+ +#pragma once + +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmDpp : public DeviceGemm +{ + using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< + BlockSize, + ADataType, + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + MPerBlock, + NPerBlock, + KPerBlock, + MPerDpp, + NPerDpp, + AK1, + BK1, + MDppPerWave, + NDppPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting"); + } + + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + else + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(ck::is_gfx103_supported() || ck::is_gfx11_supported()) + { + return GridwiseGemm::CheckValidity(karg); + } + return false; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmDpp" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerDpp << ", " + << NPerDpp << ", " + << MDppPerWave << ", " + << MDppPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d9dd192c9869d64eac683cb7d73ff8cf74c89fd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,751 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro 
Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + AsDataType, + BsDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { +#if 0 + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else +#endif + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + 
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { +#if 0 + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else +#endif + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return 
GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout_ = remove_cvref_t>; + + static_assert(is_same::value, ""); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout_ = remove_cvref_t>; + + static_assert(is_same::value, ""); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout_ = remove_cvref_t>; + + static_assert(is_same::value, ""); + }); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<( @@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle return (workspace_size); }; - void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override { Argument* pArg_ = dynamic_cast(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 916f29a904eba4172b13651249263d085b16f378..bb2db930c8ea7b1ed34b17a7cde56a2c7c6daafc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -61,7 +61,7 @@ __global__ void const Block2ETileMap 
block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index 44b3518e2c950e15ec3d5cfd240c443029987622..35f1c77f87994852a5524547432180c274a3c113 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" namespace ck { namespace tensor_operation { @@ -27,21 +28,22 @@ template struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 = Max Vector Access Pixels static constexpr auto K1Number = Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? 
false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; - static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } - else if constexpr(is_same_v) + else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } }(); - const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } - static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same_v) + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return 
matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } else if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } }(); - const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } template @@ -180,13 +252,13 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD; - using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); // GridwiseOp - using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseOp = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -195,8 +267,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD(p_b_grid)}, p_ds_grid_{}, p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, + a_grid_desc{}, + b_grid_desc{}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock{}, @@ -278,8 +352,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD{}([&](auto i) { using DLayout = remove_cvref_t>; using DDataType = remove_cvref_t>; @@ -295,8 +369,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t< @@ -404,68 +468,35 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD, - true>; // Last Option is W/O - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - 
arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.block_2_ctile_map_); + has_main_k_block_loop>; // Last Option is W/O + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_grid_desc, + arg.b_grid_desc, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.block_2_ctile_map_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); } else { - const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< - GridwiseOp, - ADataType, - BDataType, - typename GridwiseOp::DsGridPointer, - EDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t< - typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - remove_reference_t, - false>; - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } - - return ave_time; } // polymorphic @@ -484,8 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD || is_same_v)) { @@ -576,8 +606,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD" - << " NumPrefetch: " + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " << NumPrefetch << ", " << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 6502155f87aeba859f79a2321dfad7a38b4dfbf2..f5735b39cc532605cdf14dd8854e016e626d8d36 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -22,7 +22,8 @@ namespace ck { template (p_a_grid, @@ -145,7 +146,8 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeDataType = EDataType> struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct 
DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad + : public DeviceGemmMultipleD +{ + static constexpr auto I1 = Number<1>{}; + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + struct Invoker : public BaseInvoker + { + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + typename GridwiseGemm::AGridDesc_AK0_M_AK1, + typename GridwiseGemm::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.a_grid_desc_m_k_.GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!ck::is_lds_direct_load_supported()) + { + return false; + } + + // Check vector load/store. 
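The checks that follow encode a simple rule: vectorized global accesses run along the dimension that is memory-contiguous for the given layout, and that dimension's raw extent must be divisible by the configured scalar-per-vector count (K for the layouts that are contiguous along K, M or N otherwise). A compact restatement under those assumptions:

// Hedged restatement of the vector-access rule enforced below: the raw extent of
// the memory-contiguous dimension must be a multiple of the vector width, or the
// tail access would fall outside the tensor. Illustrative helper only.
constexpr bool contiguous_extent_is_vectorizable(int contiguous_raw_extent,
                                                 int scalar_per_vector)
{
    return scalar_per_vector > 0 && contiguous_raw_extent % scalar_per_vector == 0;
}

// Example: row-major A (M x K) is contiguous along K, so KRaw is checked against
// ABlockTransferScalarPerVector; a column-major A checks MRaw instead.
static_assert(contiguous_extent_is_vectorizable(1024, 8), "1024 % 8 == 0");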
+ { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // Check vector load of A. + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + if(arg.MRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of B. + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + if(arg.NRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of Ds. + // For now, only the RowMajor layout is supported. + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // Check vector load of E. + // For now, only the RowMajor layout is supported. + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << 
BBlockTransferScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7661114ef804e32e777ed642d6c67f6632ecd3ee --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,781 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } 
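+                        // The KBatch == 1 path mirrors the KBatch > 1 path above, but the
+                        // epilogue writes with InMemoryDataOperationEnum::Set instead of
+                        // AtomicAdd, since no partial results have to be accumulated across
+                        // K-splits (and no prior memset of C is required). The tail handling
+                        // below (TailNumber::Two .. TailNumber::Seven) is only compiled in
+                        // when the blockwise GEMM pipe actually exposes that many prefetch
+                        // stages.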
+ + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + 
minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, 
"v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 + : public DeviceGemmMultipleD_ABScale +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + }; + + constexpr index_t minimum_occupancy = + (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave && + MPerBlock * NPerBlock / BlockSize > 64) + ? 
1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const 
index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 = Max Vector Access Pixels static constexpr auto K1Number = Number{}; - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && + is_same::value) + ? false + : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) && + is_same::value) + ? 
false + : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); - static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } - else if constexpr(is_same_v) + else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } }(); - const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } - static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same_v) + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } else if constexpr(is_same_v) { - 
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } }(); - const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) @@ -159,56 +237,58 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm; + using GridwiseGemm = + GridwiseGemm_Wmma; // Argument struct Argument : public BaseArgument @@ -230,7 +310,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, remove_reference_t, - true>; // Last Option is W/O - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + has_main_k_block_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); } else { - const auto kernel = kernel_gemm_wmma< - GridwiseGemm, - ADataType, - BDataType, - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - 
remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } - - return ave_time; } // polymorphic @@ -411,16 +450,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v)) + if constexpr(!(is_same_v || is_same_v || + is_same_v)) { + printf("DeviceOp err: AccDataType"); return false; } } else { + printf("DeviceOp err: Arch"); return false; } @@ -486,7 +527,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm" - << " NumPrefetch: " + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " << NumPrefetch << ", " << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp index ebbb1d38c000ba44ce53f9a8110114c41a299b14..5188ece3335d9be07e4b64c74811e23fee86c382 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -184,7 +184,7 @@ struct DeviceGemmXdl : public DeviceGemm || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp index 86bc7b5bbb0640c4c2bf0e947b3adae56b358f13..dd41f4bca0a3ec7ceb27d43db553b0a6a5b265bb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -66,7 +66,8 @@ template + typename ComputeTypeA = CDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; + ComputeTypeA, + ComputeTypeB>; using Argument = typename GridwiseGemm::Argument; @@ -276,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm +{ + static constexpr auto I1 = Number<1>{}; + + using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsExtraM, 
+ BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + typename GridwiseGemm::AGridDesc_AK0_M_AK1, + typename GridwiseGemm::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + // arg.p_ds_grid_, + ck::Tuple<>{}, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.a_grid_desc_m_k_.GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!ck::is_lds_direct_load_supported()) + { + return false; + } + + // Check vector load/store. + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // Check vector load of A. + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + if(arg.MRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of B. + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + if(arg.NRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of E. + // For now, only the RowMajor layout is supported. 
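+            // Illustration (values are only an example): with a row-major E and
+            // CDEBlockTransferScalarPerVector_NPerBlock == 8, NRaw must be a multiple of 8
+            // so the C-shuffle epilogue can issue vectorized stores along the contiguous
+            // N dimension; otherwise the argument is rejected here.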
+ if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + using EmptyDsPointers = std::array; + using EmptyDsStrides = std::array; + + return Argument{p_a, + p_b, + EmptyDsPointers{}, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + EmptyDsStrides{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + using EmptyDsPointers = std::array; + using EmptyDsStrides = std::array; + + return std::make_unique(p_a, + p_b, + EmptyDsPointers{}, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + EmptyDsStrides{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer] << ", " + << "Prefetch: " + << NumGemmKPrefetchStage; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..452063156e2ecfa983ddc99d6896adf18505d1c2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
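+
+// A minimal usage sketch (illustrative only; `gemm`, `arg`, `invoker` and the
+// pointer/stride variables are placeholder names, and the template arguments
+// are elided). The device op is driven through MakeArgument/MakeInvoker:
+//
+//   auto gemm = DeviceGemm_Xdl_CShuffle_Streamk_V3</* ... */>{};
+//   auto arg  = gemm.MakeArgument(p_a, p_b, p_c, M, N, K,
+//                                 StrideA, StrideB, StrideC,
+//                                 streamk_sel, Grid_size,
+//                                 a_op, b_op, c_op);
+//   if(gemm.IsSupportedArgument(arg))
+//   {
+//       auto invoker = gemm.MakeInvoker();
+//       invoker.Run(arg, StreamConfig{});
+//   }
+//
+// Passing a negative Grid_size lets Invoker::Run size the launch grid from the
+// device's CU count and the kernel's measured occupancy (see below).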
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + float ave_time = 0; + + index_t k_grain = KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + hipGetErrorString(hipMemsetAsync( + arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_)); + const auto Run = [&](const auto& kernel) { + dim3 grid_dim; + if(arg.Grid_size < 0) + { + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, 0); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + arg.Grid_size = num_cu * occupancy; + grid_dim = arg.Grid_size; + } + else + grid_dim = arg.Grid_size; + + if(stream_config.flush_cache) + { + Argument arg_ = arg; + ck::utility::RotatingMemWrapper rotating_mem( + arg_, + stream_config.rotating_count, + arg_.M * arg_.K * sizeof(ADataType), + arg_.K * arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_); + } + else + { + + ave_time = launch_and_time_kernel( + stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + 
} + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + + return Argument{ + p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + 
{BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffleV2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffleV2" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff 
--git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4489b2e5ce56bab4e6be7549bf22f500e7861a69 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -0,0 +1,736 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { 
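                            // Note (editorial sketch, not part of the original patch): this
                            // KBatch == 1 branch mirrors the AtomicAdd dispatch above but
                            // presumably selects kernels that store C directly, since the
                            // hipMemsetAsync pre-clear and the AtomicAdd memory op are only used
                            // when arg.KBatch > 1. Every per-tail selection in this file follows
                            // one pattern:
                            //
                            //   if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > n)
                            //       if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::/*n*/)
                            //           Run(kernel_gemm_xdl_cshuffle_v3</* ..., tail = n */>);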
+ const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // 
polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1 +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + ReduceDataType, + AElementwiseOperation, + BElementwiseOperation, + PassThrough, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 
ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + const std::array p_ds_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + reinterpret_cast(p_c_grid_), + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + k_batch_, + true), + p_ds(p_ds_), + StrideDs(StrideDs_) + { + } + + const std::array p_ds; + std::array StrideDs; + }; + + using ReduceAdd = ck::reduce::Add; + using OutElementwiseOperation = CElementwiseOperation; + + static constexpr auto DsVectorLengthSequence = generate_sequence_v2( + [](auto i) { + using DLayout = remove_cvref_t>; + if constexpr(std::is_same::value) + return Number{}; + else + return Number<1>{}; + }, + Number{}); + + using DeviceReduceInstance = DeviceReduceThreadWiseMultiD< + ReduceDataType, // InDataType, + DsDataType, // DsDatatype + GemmAccDataType, // AccDataType, + CDataType, // OutDataType, + 3, // Rank + 1, // NumReduceDim + ReduceAdd, + PassThrough, + OutElementwiseOperation, + 256, // BlockSize_, + CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_, + 1, // KThreadSliceSize_, + 0, // InSrcVectorDim_, + CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_, + CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_ + decltype(DsVectorLengthSequence)>; + + // Invoker + struct Invoker : public BaseInvoker + { + float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + static constexpr index_t NumInDim = 3; + static constexpr index_t NumOutDim = 2; + + std::array in_lengths = {arg.KBatch, arg.M, arg.N}; + std::array out_lengths = {arg.M, arg.N}; + + std::array in_strides; + std::array out_strides; + if constexpr(std::is_same::value) + { + in_strides = {arg.M * arg.N, arg.N, 1}; + out_strides = {arg.N, 1}; + } + else + { + in_strides = {arg.M * arg.N, 1, arg.M}; + out_strides = {1, arg.M}; + } + + std::array reduce_dims{0}; + + std::array, NumDTensor> DsLengths; + std::array, NumDTensor> DsStrides; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + DsLengths[i] = out_lengths; + + using DLayout = remove_cvref_t>; + if constexpr(std::is_same::value) + { + DsStrides[i] = {arg.StrideDs[i], 1}; + } + else + { + DsStrides[i] = {1, arg.StrideDs[i]}; + } + }); + + auto reduce = DeviceReduceInstance{}; + + auto argument_ptr = reduce.MakeArgumentPointer(in_lengths, + in_strides, + DsLengths, + DsStrides, + out_lengths, + out_strides, + reduce_dims, + arg.p_workspace_, + arg.p_ds, + arg.p_c_grid, + PassThrough{}, + OutElementwiseOperation{}); + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + float ave_time = 0; + + if(reduce.IsSupportedArgument(argument_ptr.get())) + { + ave_time = invoker_ptr->Run(argument_ptr.get(), 
stream_config); + } + else + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + } + + return ave_time; + } + + float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{}) + { + auto arg = *dynamic_cast(&arg_); + + if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && + std::is_same::value)) + { + if(arg.p_workspace_ == nullptr) + { + throw std::runtime_error("using reduce , but empty workspace!"); + } + + arg.p_c_grid = reinterpret_cast(arg.p_workspace_); + } + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + ck::utility::RotatingMemWrapper rotating_mem( + arg, + stream_config.rotating_count, + arg.M * arg.K * sizeof(ADataType), + arg.K * arg.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg); + } + else + { + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && + std::is_same::value)) + { + // reduce c data + ave_time += RunReduce(arg_, stream_config); + } + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == 
GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const std::array p_ds, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversalReduce" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<(p_arg); + + if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && + std::is_same::value)) + { + std::cout << "using workspace" << std::endl; + return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType); + } + + return 0; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp index a8cd39080cabc3fac597ae481a8d72d668652202..fd6f3b65f2f5239d260accd0be61c2fd9e13dc01 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp @@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_ak0_m_ak1_{" << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " @@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator std::cout << "arg.c_grid_desc_m_n_{ " << 
arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp index d54ddf433e49f96c7617de50bb016d427974f4f0..c2a27ebbdb849b8210b2f0d83c41d348ad383f00 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm + PipelineVersion PipelineVer = PipelineVersion::v1, + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename LDSTypeA = ComputeType, + typename LDSTypeB = ComputeType> struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK + CElementwiseOperation, + ComputeType> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -78,7 +82,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0Padded_, + k_batch_), + a_element_op(a_element_op_), + b_element_op(b_element_op_), + c_element_op(c_element_op_) + { + } + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; - using Argument = typename GridwiseGemm::Argument; using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; // Invoker @@ -154,9 +206,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); }; if(has_main_k0_block_loop) @@ -179,7 +240,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -189,7 +253,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -202,7 +269,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -212,7 +282,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -260,12 +333,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(static_cast(p_a), @@ -310,8 +386,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + 
{PipelineVersion::v2, "v2"}}; + + str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched] + << ", PipelineVersion: " << PipelineVersionToString[PipelineVer]; + + return str.str(); + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd33e577bfdd081312b67c1db5a9d9b176c7a0bf --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -0,0 +1,423 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template + +struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = GridwiseGemm_xdlops_splitk_lds_direct_load< + BlockSize, + ADataType, + BDataType, + AccDataType, + CDataType, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeType>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0Padded_, + k_batch_), + a_element_op(a_element_op_), + b_element_op(b_element_op_), + c_element_op(c_element_op_) + { + } + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; + + using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Invoker + struct Invoker : public 
BaseInvoker + { + + void Print(const Argument& karg) { karg.Print(); } + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(karg); + } + + const auto kbatch = karg.k_batch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); + } + + const auto b2c_map = DefaultBlock2CTileMap{}; + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch); + const auto K0Padded = karg.K0Padded; + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(kbatch > 1) + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + static_cast(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + 
CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemmXdlSplitKCShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer] << ", " + << "Prefetch: " + << NumGemmKPrefetchStage; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp index 8de42ba9efcdc018d0d945ace9aa3e2da9a41cbf..51b8958d61fe99a67ce1a9815c8b1495c6a8967a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp @@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK(pArg); @@ -241,9 +243,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 0b1bf0f1792cec386ffe3a2cdf1444d18f14e4cc..cc022b89c5ed472758300a1cd4e19cc5b2924f09 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -38,7 +38,7 @@ __global__ void const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -355,9 +355,13 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, 
AccDataType, CShuffleDataType, DsDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3fb047f207b723563d50951b61a0d18c31d2faf0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,882 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + // TODO: Add support for different A and B data types. 
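    // Note (editorial): the alias on the next line collapses A and B to one element type; the
    // TODO above tracks lifting that restriction so the output-image (A) and weight (B) tensors
    // could use different data types.
    //
    // Hedged host-side usage sketch (shapes illustrative, taken from the operator comment above;
    // not part of the original patch). For NDimSpatial == 2 a caller would pack
    //   a_g_n_k_wos_lengths = {G, N, K, Ho, Wo}   // output image -> GEMM A
    //   b_g_k_c_xs_lengths  = {G, K, C, Y,  X }   // weight       -> GEMM B
    //   e_g_n_c_wis_lengths = {G, N, C, Hi, Wi}   // input image  -> GEMM E
    // plus the matching strides, dilations, and paddings, then call MakeArgument(...),
    // check IsSupportedArgument(arg), and launch through MakeInvoker().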
+ using ABDataType = ADataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr index_t KPerBlock = K0PerBlock * K1; + + static constexpr auto transform_conv_to_gemm = + TransformConvBwdDataToGemm_v1{}; + + static auto GetDummyABDsEGridDescriptor() + { + const std::array dummy_tensor_lengths = {1}; + const std::array dummy_tensor_strides = {1}; + const std::array dummy_spatial_lengths = {1}; + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto ds_grid_desc_m_n = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return transform_conv_to_gemm.template MakeCDescriptor_M_N( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + }, + Number{}); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + + // desc + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_Wmma< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + KPerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + true, + BBlockLdsExtraN, + 
CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + a_g_n_k_wos_strides_{a_g_n_k_wos_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths}, + ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + e_g_n_c_wis_strides_{e_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? 
Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(i) = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + }); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + // for check validity + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + // desc for blockwise copy + a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); + 
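                    // (the matching B descriptor is pushed next; the Invoker walks all of these
                    // containers with a single index, one entry per (ztilde, ytilde, xtilde)
                    // slice)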
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); + + // block-to-e-tile-map + auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap( + e_grid_desc_m_n, 1 /* M01 */, 1 /* N01 */); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)); + } + } + } + } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + std::cout << "a_grid_desc_ak0_m_ak1_container_" + << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; + + std::cout << "b_grid_desc_bk0_n_bk1_container_" + << b_grid_desc_bk0_n_bk1_container_[i] << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j] + << std::endl; + }); + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + std::vector + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // block-to-e-tile map + std::vector block_2_ctile_map_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_k_wos_lengths_; + std::array a_g_n_k_wos_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_c_wis_lengths_; + std::array, NumDTensor> ds_g_n_c_wis_strides_; + std::array e_g_n_c_wis_lengths_; + std::array e_g_n_c_wis_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.e_grid_desc_m_n_container_[i]) * + arg.num_group_; + + const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + 
typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_k_wos_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.block_2_ctile_map_container_[i], + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + { + ave_time += launch_kernel(integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector load D matrix from global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + ds_valid = false; + } + }); + + if(!ds_valid) + { + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + 
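            // Each GEMM produced by the backward-data (ztilde, ytilde, xtilde) decomposition is
            // validated on its own; a single failing slice rejects the whole argument.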
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + + // check number of dimension, only implemented for 2D and 3D now + if(NDimSpatial != 2 && NDimSpatial != 3) + { + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << 
getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 8a0a4453702a081c02027653e3c13fa61b8e8c46..b544c925e1394ac05012ad37e09d4611e6da2606 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -24,51 +25,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. 
* @@ -131,17 +87,17 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t b_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t e_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -242,7 +198,9 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename AComputeType = ADataType, + typename BComputeType = AComputeType> struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 : public DeviceGroupedConvBwdDataMultipleD + CDEElementwiseOp, + AComputeType, + BComputeType> { - // FIXME + // TODO: Extend support for more spatial dimensions. static_assert(NDimSpatial == 2 || NDimSpatial == 3, "wrong! only implemented for 2D and 3D now"); @@ -265,7 +225,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static constexpr index_t NumDTensor = DsDataType::Size(); - // TODO make A/B datatype different + // TODO: Add support for different A and B data types. 
using ABDataType = ADataType; static constexpr auto I0 = Number<0>{}; @@ -280,6 +240,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 BK1, MPerBlock, NPerBlock, + KPerBlock, DoPadGemmM, DoPadGemmN>{}; @@ -355,7 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ABDataType, // TODO: distinguish A/B datatype + ABDataType, + ABDataType, + AComputeType, AccDataType, CShuffleDataType, DsDataType, @@ -395,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; + LoopSched, + PipelineVersion::v1, + BComputeType>; template static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) @@ -712,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::vector block_2_etile_map_container_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOp a_element_op_; @@ -781,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b3f1a0255671b7cdf8c6fb283bdb4e3a3314912 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -0,0 +1,1234 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
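+//
+// This header introduces DeviceGroupedConvBwdWeight_Dl: a grouped convolution backward-weight
+// operation built on the DL gridwise GEMM (GridwiseGemmDl_bkm_bkn_mn_v1r3). The weight gradient
+// is computed as a batched GEMM with A = output tensor, B = input tensor and C = weight tensor,
+// and split_k must be 1 for this DL path (see IsSupportedArgument).
+//
+// Minimal usage sketch; the template arguments and tensor parameters below are placeholders,
+// not a tested configuration:
+//
+//   using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl</* ... */>;
+//   DeviceOp op;
+//   auto arg = op.MakeArgument(p_in, p_wei, p_out, /* lengths, strides, conv params */ ...,
+//                              in_element_op, wei_element_op, out_element_op, /*split_k=*/1);
+//   if(op.IsSupportedArgument(arg)) { op.MakeInvoker().Run(arg, StreamConfig{}); }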
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_dlops_bwd_weight( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, + const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; + + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_kbatch_k0_m0_m1_k1, + b_grid_desc_kbatch_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_kbatch_k0_m0_m1_k1; + ignore = b_grid_desc_kbatch_k0_n0_n1_k1; + ignore = c_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight +{ + using DeviceOp = DeviceGroupedConvBwdWeight_Dl; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static 
constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto spatial_offset = I3; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) + { + using namespace ck; + + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset]; + const index_t InLeftPadW = input_left_pads[I0]; + const index_t InRightPadW = input_right_pads[I0]; + const index_t ConvStrideW = conv_filter_strides[I0]; + const index_t ConvDilationW = conv_filter_dilations[I0]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); + + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wi, C), make_tuple(InWStride, InCStride)); + + const auto in_gemmkpad_gemmnpad_grid_desc = + 
ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weights tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride)); + + // A: output tensor + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 
3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + + } // function end + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) + { + using namespace ck; + + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1]; + + const index_t InLeftPadH = input_left_pads[I0]; + const index_t InLeftPadW = input_left_pads[I1]; + const index_t InRightPadH = input_right_pads[I0]; + const index_t InRightPadW = input_right_pads[I1]; + const index_t ConvStrideH = conv_filter_strides[I0]; + const index_t ConvStrideW = conv_filter_strides[I1]; + const index_t ConvDilationH = conv_filter_dilations[I0]; + const index_t ConvDilationW = conv_filter_dilations[I1]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride)); + + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride)); + + // A: output tensor + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmnpad_grid_desc = + 
ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + + } // function end + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) + { + using namespace ck; + + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2]; + const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2]; + const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2]; + + const index_t InLeftPadD = input_left_pads[I0]; + const index_t InLeftPadH = input_left_pads[I1]; + const index_t InLeftPadW = input_left_pads[I2]; + const index_t InRightPadD = input_right_pads[I0]; + const index_t InRightPadH = input_right_pads[I1]; + const index_t InRightPadW = input_right_pads[I2]; + const index_t ConvStrideD = conv_filter_strides[I0]; + const index_t ConvStrideH = conv_filter_strides[I1]; + const index_t ConvStrideW = conv_filter_strides[I2]; + const index_t ConvDilationD = conv_filter_dilations[I0]; + const index_t ConvDilationH = conv_filter_dilations[I1]; + const index_t ConvDilationW = conv_filter_dilations[I2]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InDStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; 
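+ // GemmKTotal is the GEMM reduction length: every output position (N * Do * Ho * Wo) contributes + // to the weight gradient. It is padded and unmerged below into KBatch x GemmK0 x GemmK1 so the + // split-K factor (KBatch) and the K0PerBlock x K1 reduction tiling can be applied.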
+ const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride)); + + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride)); + + // A: output tensor + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + 
make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_B_K0_M_K1 = remove_cvref_t; + using BGridDesc_B_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = + GridwiseGemmDl_bkm_bkn_mn_v1r3; + + // Argument + using AGridDesc_B_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{})); + using BGridDesc_B_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{})); + using 
CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_G_{a_g_n_c_wis_lengths[I0]}, + Conv_K_{b_g_k_c_xs_lengths[I1]}, + Conv_C_{a_g_n_c_wis_lengths[I2]}, + filter_lengths_{b_g_k_c_xs_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + a_grid_desc_kbatch_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_); + b_grid_desc_kbatch_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + ck::index_t M01 = 1; + ck::index_t N01 = 1; + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0]; + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + AGridDesc_B_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_B_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_; + BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + // DefaultBlock2CTileMap block_2_ctile_map_; + Block2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + + // element-wise op + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation 
c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_K_; + const index_t Conv_C_; + + std::array filter_lengths_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + + ShowInfo(arg); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm GridwiseGemmDl_bkm_bkn_mn_v1r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop.value; + + const auto kernel = kernel_batched_gemm_dlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch<>, + has_main_loop, + has_double_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m0_m1_k1_, + arg.b_grid_desc_kbatch_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool 
IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + + // DL version only supports split_k equal to 1 + if(arg.k_batch_ != 1) + return false; + + if constexpr(!((NDimSpatial == 1 && + (is_NWGC_GKXC_NWGK() || + is_GNWC_GKXC_GNWK())) || + (NDimSpatial == 2 && + (is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK())) || + (NDimSpatial == 3 && + (is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())))) + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_lengths_[spatial_offset + i] == 1 && + arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 && + arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // matrix A + { + auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{}; + if(srcVectorLengths[I2] != 1 || srcVectorLengths[I3] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I4] != 0 || K0PerBlock % srcVectorLengths[I1] != 0) + { + return false; + } + + const index_t K = arg.Conv_K_; + + if(K % (srcVectorLengths[I1] * srcVectorLengths[I4]) != 0) + { + return false; + } + } + + // matrix B + { + auto srcLoadLengths = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{}; + auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I4] != 1) + { + return false; + } + if(srcLoadLengths[I2] % srcVectorLengths[I2] != 0 || + srcLoadLengths[I3] % srcVectorLengths[I3] != 0) + { + return false; + } + + const index_t C = arg.Conv_C_; + + if(C % (srcVectorLengths[I2] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + std::cout << "Not supported, because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = " + << arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl; + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return 
Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp deleted file mode 100644 index 198751cdf350f7f9d828edda641b8ef217e98af6..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp +++ /dev/null @@ -1,1226 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" -#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -namespace { - -struct ComputePtrOffsetOfStridedBatch -{ - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; -}; - -} // namespace - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_dlops_bwd_weight( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, - const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) -{ - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - - __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; - - GridwiseGemm::template Run( - p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_grid_desc_kbatch_k0_m0_m1_k1, - b_grid_desc_kbatch_k0_n0_n1_k1, - c_grid_desc_m0_m10_m11_n0_n10_n11, - block_2_ctile_map, - integral_constant{}, - integral_constant{}); -} - -template -struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl - : public DeviceGroupedConvBwdWeight< - NDimSpatial, - ck::tuple_element_t>, - ck::tuple_element_t>, - ck::tuple_element_t>, - InDataType, - WeiDataType, - OutDataType, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> -{ - using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl; - - using ADataType = OutDataType; - using BDataType = InDataType; - using CDataType = WeiDataType; - - using AElementwiseOperation = OutElementwiseOperation; - using BElementwiseOperation = InElementwiseOperation; - using CElementwiseOperation = WeiElementwiseOperation; - - // TODO make A/B datatype different - 
using ABDataType = InDataType; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - - static constexpr auto K1Number = Number{}; - static constexpr auto GemmK1Number = K1Number; - - // Bytes per 32 lds bank: 32 * 4 bytes - static constexpr auto BankLength = 128; - static constexpr auto ElePerBank = BankLength / sizeof(ADataType); - - // M1 & M0 - static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; - static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; - static constexpr auto ABlockLdsM1Padding = 4; - - // N1 & N0 - static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; - static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; - static constexpr auto BBlockLdsN1Padding = 4; - - template ::type = false> - static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const ck::index_t batch_k) - { - using namespace ck; - - const index_t Wi = input_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - const index_t ConvStrideW = conv_filter_strides[0]; - const index_t ConvDilationW = conv_filter_dilations[0]; - - const index_t GemmKTotal = N * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X; - - const index_t GemmKBatch = batch_k; - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * - K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); - - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = 
transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weights tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); - } - else - { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); - const auto in_n_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); - - // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( - in_n_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmktotal_gemmn_grid_desc = - transform_tensor_descriptor(in_n_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(X, C)), - make_merge_transform(make_tuple(N, Wo))), - make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); - } - - } // function end - template ::type = false> - static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, 
- const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const ck::index_t batch_k) - { - using namespace ck; - - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = output_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[1]; - - const index_t Y = filter_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t GemmKTotal = N * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X * Y; - - const index_t GemmKBatch = batch_k; - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * - K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); - } - else - { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - const auto in_n_hi_wi_c_grid_desc = - 
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmktotal_gemmn_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); - } - - } // function end - - template ::type = false> - static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const ck::index_t batch_k) - { - using namespace ck; - - const index_t Di = input_spatial_lengths[0]; - const index_t Hi = input_spatial_lengths[1]; - 
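The pad / embed / merge chain above is an implicit im2col: GEMM row gemm_k decomposes as (n, ho, wo) and GEMM column gemm_n as (y, x, c), and together they address input element (n, ho * StrideH + y * DilationH - LeftPadH, wo * StrideW + x * DilationW - LeftPadW, c), with out-of-range positions falling into the zero pad. A scalar reference of that mapping (illustrative only; input assumed packed as [N][Hi][Wi][C] for one group):

    #include <vector>

    // Reference for the B-matrix element addressed by the 2-D pad/embed/merge transforms.
    inline float gemm_b_element_2d(const std::vector<float>& in, // packed [N][Hi][Wi][C]
                                   int n, int ho, int wo,        // gemm_k = (n * Ho + ho) * Wo + wo
                                   int y, int x, int c,          // gemm_n = (y * X + x) * C + c
                                   int Hi, int Wi, int C,
                                   int stride_h, int stride_w,
                                   int dilation_h, int dilation_w,
                                   int left_pad_h, int left_pad_w)
    {
        const int hi        = ho * stride_h + y * dilation_h - left_pad_h;
        const int wi        = wo * stride_w + x * dilation_w - left_pad_w;
        const bool in_range = hi >= 0 && hi < Hi && wi >= 0 && wi < Wi;
        return in_range ? in[((n * Hi + hi) * Wi + wi) * C + c] : 0.0f;
    }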
const index_t Wi = input_spatial_lengths[2]; - - const index_t Do = output_spatial_lengths[0]; - const index_t Ho = output_spatial_lengths[1]; - const index_t Wo = output_spatial_lengths[2]; - - const index_t Z = filter_spatial_lengths[0]; - const index_t Y = filter_spatial_lengths[1]; - const index_t X = filter_spatial_lengths[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t GemmKTotal = N * Do * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * Z * X * Y; - - const index_t GemmKBatch = batch_k; - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * - K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); - } - else - { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - const auto in_n_di_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - - // A: 
output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_dip_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), - make_merge_transform(make_tuple(N, Do, Ho, Wo))), - make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); - } - - } // function end - - template ::type = false> - static auto GetABCGridDesc() - { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( - 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); - } - - template ::type = false> - static auto GetABCGridDesc() - { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( - 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); - } - - template ::type = false> - static auto 
GetABCGridDesc() - { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, - 1, - 1, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - 1); - } - - using ABCGridDescs = decltype(GetABCGridDesc()); - - using AGridDesc_B_K0_M_K1 = remove_cvref_t; - using BGridDesc_B_K0_N_K1 = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; - - using GridwiseGemm = - GridwiseGemmDl_bkm_bkn_mn_v1r3; - - // Argument - using AGridDesc_B_K0_M0_M1_K1 = - decltype(GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{})); - using BGridDesc_B_K0_N0_N1_K1 = - decltype(GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{})); - using CGridDesc_M0_M10_M11_N0_N10_N11 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); - using Block2CTileMap = - decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); - - struct Argument : public BaseArgument - { - Argument(const InDataType* p_in_grid, - WeiDataType* p_wei_grid, - const OutDataType* p_out_grid, - const std::array& a_g_n_c_wis_lengths, // input - const std::array& /*a_g_n_c_wis_strides*/, - const std::array& b_g_k_c_xs_lengths, // weight - const std::array& /*b_g_k_c_xs_strides*/, - const std::array& e_g_n_k_wos_lengths, // output - const std::array& /*e_g_n_k_wos_strides*/, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - ck::index_t split_k) - : p_a_grid_{p_out_grid}, - p_b_grid_{p_in_grid}, - p_c_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, - c_grid_desc_m_n_{}, - block_2_ctile_map_{}, - compute_ptr_offset_of_batch_{}, - a_element_op_{out_element_op}, - b_element_op_{wei_element_op}, - c_element_op_{in_element_op}, - Conv_G_{a_g_n_c_wis_lengths[0]}, - Conv_N_{a_g_n_c_wis_lengths[1]}, - Conv_K_{b_g_k_c_xs_lengths[1]}, - Conv_C_{a_g_n_c_wis_lengths[2]}, - input_spatial_lengths_{}, - filter_spatial_lengths_{}, - output_spatial_lengths_{}, - conv_filter_strides_{conv_filter_strides}, - conv_filter_dilations_{conv_filter_dilations}, - input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads}, - k_batch_{split_k} - { - constexpr index_t spatial_offset = 3; - std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, - end(a_g_n_c_wis_lengths), - begin(input_spatial_lengths_)); - std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, - end(b_g_k_c_xs_lengths), - begin(filter_spatial_lengths_)); - std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, - end(e_g_n_k_wos_lengths), - begin(output_spatial_lengths_)); - - const auto descs = - DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - Conv_N_, - Conv_K_, - Conv_C_, - input_spatial_lengths_, - filter_spatial_lengths_, - output_spatial_lengths_, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - k_batch_); - - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; - - a_grid_desc_kbatch_k0_m0_m1_k1_ = - GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_); - b_grid_desc_kbatch_k0_n0_n1_k1_ = - GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_); - c_grid_desc_m0_m10_m11_n0_n10_n11_ = - 
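The Argument constructor above wires the convolution tensors into GEMM roles for backward-weight: A is the output gradient, B is the input, C is the weight gradient, so the GEMM problem is M = K, N = C * prod(filter spatial), K_total = N_batch * prod(output spatial). A small sketch of that shape computation (standalone, names are illustrative):

    #include <array>
    #include <functional>
    #include <numeric>

    struct GemmShape { long m, n, k; };

    // GEMM shape for backward-weight as set up above (A = output, B = input, C = weight).
    template <int NDim>
    GemmShape bwd_weight_gemm_shape(long N, long K, long C,
                                    const std::array<long, NDim>& filter_lengths,
                                    const std::array<long, NDim>& output_lengths)
    {
        const long filter_size = std::accumulate(
            filter_lengths.begin(), filter_lengths.end(), 1L, std::multiplies<>{});
        const long output_size = std::accumulate(
            output_lengths.begin(), output_lengths.end(), 1L, std::multiplies<>{});
        return {K, C * filter_size, N * output_size};
    }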
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); - ck::index_t M01 = 1; - ck::index_t N01 = 1; - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); - - // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = - Conv_N_ * Conv_K_ * - std::accumulate(begin(output_spatial_lengths_), - end(output_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideB_ = - Conv_N_ * Conv_C_ * - std::accumulate(begin(input_spatial_lengths_), - end(input_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideC_ = - Conv_K_ * Conv_C_ * - std::accumulate(begin(filter_spatial_lengths_), - end(filter_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); - } - - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - - AGridDesc_B_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_B_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - - AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_; - BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_; - CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; - - // DefaultBlock2CTileMap block_2_ctile_map_; - Block2CTileMap block_2_ctile_map_; - - // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; - - // element-wise op - OutElementwiseOperation a_element_op_; - WeiElementwiseOperation b_element_op_; - InElementwiseOperation c_element_op_; - - // for checking IsSupportedArgument() - const index_t Conv_G_; - const index_t Conv_N_; - const index_t Conv_K_; - const index_t Conv_C_; - - std::array input_spatial_lengths_; - std::array filter_spatial_lengths_; - std::array output_spatial_lengths_; - const std::array& conv_filter_strides_; - const std::array& conv_filter_dilations_; - const std::array& input_left_pads_; - const std::array& input_right_pads_; - index_t k_batch_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - void ShowInfo(const Argument& arg) - { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - - ShowInfo(arg); - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m_n_)) - { - throw std::runtime_error( - "wrong! 
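Because the GNWC / GKXC / GNWK tensors are packed per group, the per-group strides assigned above are simply the element counts of one group's output, input, and weight tensors. With illustrative sizes N = 2, K = 64, C = 32, Wi = 30, Wo = 28, X = 3:

    // Per-group strides for packed GNWC/GKXC/GNWK tensors, as computed above.
    constexpr long batch_stride_a = 2L * 64 * 28; // output: N * K * Wo = 3584
    constexpr long batch_stride_b = 2L * 32 * 30; // input:  N * C * Wi = 1920
    constexpr long batch_stride_c = 64L * 32 * 3; // weight: K * C * X  = 6144
    static_assert(batch_stride_a == 3584 && batch_stride_b == 1920 && batch_stride_c == 6144, "");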
GridwiseGemm GridwiseGemmDl_bkm_bkn_mn_v1r3 has invalid setting"); - } - - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; - - auto launch_kernel = [&](auto has_main_k_block_loop, - auto has_double_tail_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - constexpr bool has_double_loop = has_double_tail_k_block_loop.value; - - const auto kernel = kernel_batched_gemm_dlops_bwd_weight< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - ComputePtrOffsetOfStridedBatch, - has_main_loop, - has_double_loop>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.Conv_G_, - arg.a_grid_desc_kbatch_k0_m0_m1_k1_, - arg.b_grid_desc_kbatch_k0_n0_n1_k1_, - arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, - arg.block_2_ctile_map_, - arg.compute_ptr_offset_of_batch_); - }; - - const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1); - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); - const bool has_double_tail_k_block_loop = - GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - return launch_kernel(integral_constant{}, - integral_constant{}); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - return launch_kernel(integral_constant{}, - integral_constant{}); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - return launch_kernel(integral_constant{}, - integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}, - integral_constant{}); - } - } - - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102")) - { - return false; - } - - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - // check if it's 1x1, stride=1 pad = 0 conv - for(int i = 0; i < NDimSpatial; i++) - { - if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && - arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) - { - return false; - } - } - } - - // matrix A - { - auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{}; - if(srcVectorLengths[I2] != 1 || srcVectorLengths[I3] != 1) - { - return false; - } - if(K1 % srcVectorLengths[I4] != 0 || K0PerBlock % srcVectorLengths[I1] != 0) - { - return false; - } - - const index_t K = arg.Conv_K_; - - if(K % (srcVectorLengths[I1] * srcVectorLengths[I4]) != 0) - { - return false; - } - } - - // matrix B - { - auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{}; - auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{}; - if(srcVectorLengths[I1] != 1 || srcVectorLengths[I4] != 1) - { - return false; - } - if(srcLoadLenghts[I2] % srcVectorLengths[I2] != 0 || - srcLoadLenghts[I3] % 
srcVectorLengths[I3] != 0) - { - return false; - } - - const index_t C = arg.Conv_K_; - - if(C % (srcVectorLengths[I2] * srcVectorLengths[I3]) != 0) - { - return false; - } - } - - // vector store C matrix into global memory - if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) - { - std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = " - << arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl; - return false; - } - - // Gridwise GEMM size - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_); - } - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto - MakeArgument(const InDataType* p_in_grid, - WeiDataType* p_wei_grid, - const OutDataType* p_out_grid, - const std::array& a_g_n_c_wis_lengths, // input - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, // weight - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, // output - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - ck::index_t split_k) - { - return Argument{p_in_grid, - p_wei_grid, - p_out_grid, - a_g_n_c_wis_lengths, // input - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, // weight - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, // output - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - in_element_op, - wei_element_op, - out_element_op, - split_k}; - } - - static auto MakeInvoker() { return Invoker{}; } - - std::unique_ptr - MakeArgumentPointer(const void* p_in_grid, - void* p_wei_grid, - const void* p_out_grid, - const std::array& a_g_n_c_wis_lengths, // input - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, // weight - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, // output - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - ck::index_t split_k) override - { - return std::make_unique(static_cast(p_in_grid), - static_cast(p_wei_grid), - static_cast(p_out_grid), - a_g_n_c_wis_lengths, // input - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, // weight - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, // output - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - in_element_op, - wei_element_op, - out_element_op, - split_k); - } - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << K0PerBlock << ", " - << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " - << K1 - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace 
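Callers typically drive these device ops through the type-erased interface shown above: build an argument with MakeArgumentPointer, verify it with IsSupportedArgument, then run it through the invoker. A condensed call pattern, written as a generic helper because the concrete instance type, buffers, and element ops are outside this extract and are assumptions of the sketch:

    #include <utility>

    // `DeviceOp` is assumed to expose MakeArgumentPointer / IsSupportedArgument /
    // MakeInvokerPointer as above; StreamConfig{nullptr, true} requests kernel timing.
    template <typename DeviceOp, typename... ArgCtorArgs>
    float run_conv_bwd_weight(DeviceOp& op, ArgCtorArgs&&... args)
    {
        auto argument = op.MakeArgumentPointer(std::forward<ArgCtorArgs>(args)...);
        if(!op.IsSupportedArgument(argument.get()))
            return -1.0f; // this instance cannot handle the given problem
        auto invoker = op.MakeInvokerPointer();
        return invoker->Run(argument.get(), StreamConfig{nullptr, /*time_kernel=*/true});
    }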
device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a7df1c9d570c81af44f226709d1334f7f254810c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,1092 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdlops_bwd_weight( + const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; 
+ ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = batch_count; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetCPtrOffset(0); +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle + : public DeviceGroupedConvBwdWeightMultipleD +{ + using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using EDataType = WeiDataType; + + static constexpr index_t NumDTensor = DsLayout::Size(); + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CDEElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvBwdWeightToGemm{}; + + static constexpr index_t MaxScalarPerVectorFP32 = 4; + static constexpr index_t WorkspaceInOutScalarPerVector = + is_same_v + ? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32) + : CBlockTransferScalarPerVector_NWaveNPerXdl; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( + 
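The workspace tensor of this operation is kept in AccDataType, so the vector width used for workspace reads and writes is clamped to at most MaxScalarPerVectorFP32 (4) elements when the 4-byte case applies; otherwise the configured CBlockTransferScalarPerVector_NWaveNPerXdl is used unchanged. The exact is_same_v condition is not visible in this extract, so the boolean predicate below is an assumption standing in for it:

    #include <algorithm>

    // Mirrors WorkspaceInOutScalarPerVector above; `acc_is_fp32` stands in for the
    // is_same_v<...> predicate whose template arguments are elided in this extract.
    constexpr int workspace_scalar_per_vector(int configured, bool acc_is_fp32)
    {
        constexpr int max_scalar_per_vector_fp32 = 4;
        return acc_is_fp32 ? std::min(configured, max_scalar_per_vector_fp32) : configured;
    }

    static_assert(workspace_scalar_per_vector(8, true) == 4, "clamped for fp32 workspace");
    static_assert(workspace_scalar_per_vector(8, false) == 8, "unchanged otherwise");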
dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, + BDataType, + AccDataType, + AccDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + element_wise::PassThrough, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + WorkspaceInOutScalarPerVector, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true, + 1, + PipelineVersion::v1, + ComputeTypeA, + ComputeTypeB>; + + static constexpr auto MakeElementwiseInputSequence() + { + return generate_sequence_v2( + [&](auto) constexpr { return Number{}; }, + Number{}); + } + + static constexpr auto GetDsGridPointerTuple() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + return static_cast(nullptr); + }, + Number{}); + } + + template ::type = false> + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides) + { + return generate_tuple( + [&](auto i) { + const index_t K = ds_g_k_c_xs_lengths[i][I1]; + const index_t C = ds_g_k_c_xs_lengths[i][I2]; + const index_t X = ds_g_k_c_xs_lengths[i][I3]; + const index_t CStride = ds_g_k_c_xs_strides[I2]; + const index_t KStride = ds_g_k_c_xs_strides[I1]; + + const auto wei_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(KStride, CStride)); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return wei_grid_desc; + } + else + { + const index_t GemmM = K; + const index_t GemmN = C * X; + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + return transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }, + Number{}); + } + + template ::type = false> + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides) + { + return generate_tuple( + [&](auto i) { + const index_t K = ds_g_k_c_xs_lengths[i][I1]; + const index_t C = 
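For the generic (non-1x1) path, each D-tensor descriptor above is right-padded so that GemmM and GemmN become multiples of the block tile. As written, the pad amount is MPerBlock - GemmM % MPerBlock, so GemmM = 100 with MPerBlock = 64 is padded by 28 up to 128, and an already-aligned length would gain a full extra tile of padding. A tiny restatement:

    // Right-pad amount applied to the D/E descriptors above.
    constexpr int pad_to_block(int length, int per_block) { return per_block - length % per_block; }

    static_assert(100 + pad_to_block(100, 64) == 128, "GemmM = 100, MPerBlock = 64");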
ds_g_k_c_xs_lengths[i][I2]; + const index_t Y = ds_g_k_c_xs_lengths[i][I3]; + const index_t X = ds_g_k_c_xs_lengths[i][I4]; + + const auto wei_grid_desc = + conv_to_gemm_transformer.template make_wei_grid_desc( + K, Y, X, C, ds_g_k_c_xs_strides[i]); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return wei_grid_desc; + } + else + { + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + return transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }, + Number{}); + } + + template ::type = false> + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides) + { + return generate_tuple( + [&](auto i) { + const index_t K = ds_g_k_c_xs_lengths[i][I1]; + const index_t C = ds_g_k_c_xs_lengths[i][I2]; + const index_t Z = ds_g_k_c_xs_lengths[i][I3]; + const index_t Y = ds_g_k_c_xs_lengths[i][I4]; + const index_t X = ds_g_k_c_xs_lengths[i][I5]; + + const auto wei_grid_desc = + conv_to_gemm_transformer.template make_wei_grid_desc( + K, Z, Y, X, C, ds_g_k_c_xs_strides[i]); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return wei_grid_desc; + } + else + { + const index_t GemmM = K; + const index_t GemmN = C * X * Y * Z; + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + return transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }, + Number{}); + } + + template + static void + InitElementwiseBatchStrides(const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch_, + std::array& input_batch_strides, + std::array& output_batch_strides) + { + input_batch_strides[I0] = compute_ptr_offset_of_batch_.BatchStrideC_; + output_batch_strides[I0] = compute_ptr_offset_of_batch_.BatchStrideC_; + + // input_batch_strides = {C, Ds...} + static_for<0, NumDTensor, 1>{}([&](auto i) { + input_batch_strides[i + 1] = compute_ptr_offset_of_batch_.BatchStrideDs_[i]; + }); + } + + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {})); + using CDGridDesc_M_N = decltype(concat_tuple(Tuple{}, DsGridDesc_M_N{})); + using DsGridPointerTuple = decltype(GetDsGridPointerTuple()); + using CDDataTypes = decltype(concat_tuple(Tuple{}, DsGridPointerTuple{})); + using EGridDesc_M_N = CGridDesc_M_N; + static constexpr index_t ClusterLengthMPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwise = + GridwiseElementwise, + CDDataTypes, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + decltype(MakeElementwiseInputSequence()), + 
Sequence, + I1, + I1>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument( + const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_ds_grid_{}, + p_e_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + ce_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + cde_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + static_assert(is_same_v, "Not supported D data layout"); + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_k_c_xs_strides[i][0]; + }); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + ds_grid_descs_tuple_ = + MakeDsGridDescriptor_M_N(ds_g_k_c_xs_lengths, ds_g_k_c_xs_strides); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_); + elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ + ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; 
+ + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + ce_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_); + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + DsGridPointerTuple p_ds_grid_; + EDataType* p_e_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N ce_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + DsGridDesc_M_N ds_grid_descs_tuple_; + + Block2CTileMap block_2_ctile_map_; + Block2TileMapElementwise elementwise_block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", " + << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
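GetWorkspaceSizeBytes() above sizes the split-K accumulation buffer: one AccDataType element per entry of the (padded) M x N weight GEMM descriptor, replicated for all Conv_G_ groups. With fp32 accumulation, a padded GEMM of 128 x 1152 and G = 8, that is 4 * 128 * 1152 * 8 = 4,718,592 bytes (4.5 MiB):

    // Workspace sizing as in GetWorkspaceSizeBytes(), with illustrative numbers.
    constexpr long workspace_bytes(long elem_size, long m_n_element_space, long groups)
    {
        return elem_size * m_n_element_space * groups;
    }
    static_assert(workspace_bytes(4, 128L * 1152L, 8) == 4'718'592, "4.5 MiB example");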
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { + AccDataType* p_c_grid = type_convert(arg.p_workspace_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; + + constexpr bool has_main_loop = has_main_k_block_loop.value; + + auto preprocess = [&]() { + hip_check_error(hipMemsetAsync( + p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, + BDataType, + AccDataType, + OutElementwiseOperation, + InElementwiseOperation, + element_wise::PassThrough, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + preprocess, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + p_c_grid, + arg.a_element_op_, + arg.b_element_op_, + element_wise::PassThrough{}, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + auto launch_elementwise_kernel = [&]() { + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); + const index_t grid_size = + arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * + arg.Conv_G_; + + std::array input_batch_strides; + std::array output_batch_strides; + InitElementwiseBatchStrides( + arg.compute_ptr_offset_of_batch_, input_batch_strides, output_batch_strides); + + const auto kernel = kernel_batched_elementwise, + CDDataTypes, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + NumDTensor + I1, + I1>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + concat_tuple(make_tuple(arg.ce_grid_desc_m_n_), arg.ds_grid_descs_tuple_), + make_tuple(arg.ce_grid_desc_m_n_), + concat_tuple(make_tuple(p_c_grid), arg.p_ds_grid_), + arg.p_e_grid_, + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_, + arg.Conv_G_, + input_batch_strides, + output_batch_strides); + }; + + float avg_time = 0; + if(has_main_k0_block_loop) + { + avg_time = launch_gemm_kernel(integral_constant{}); + } + else + { + avg_time = launch_gemm_kernel(integral_constant{}); + } + + avg_time += launch_elementwise_kernel(); + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWC_GKXC_GNWK()) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())) + { + return false; + } + } + else + { + return false; + } + + if 
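Run() above is a two-stage pipeline: the workspace is zeroed (hipMemsetAsync in the preprocess lambda), every split-K slice atomically accumulates its partial weight gradient into the AccDataType workspace, and a second, elementwise kernel converts the accumulated values to WeiDataType while applying the D tensors and the CDE elementwise op. A scalar sketch of the same data flow (single group, no tiling; the functor stands in for CDEElementwiseOperation):

    #include <cstddef>
    #include <vector>

    // Two-stage split-K reduction: accumulate partials in float, then convert and fuse.
    template <typename WeiT, typename Fuse>
    void two_stage_bwd_weight(const std::vector<std::vector<float>>& partials, // one per split-K slice
                              const std::vector<float>& d0,                    // a single D tensor
                              std::vector<WeiT>& wei_out,
                              Fuse cde_op)                                     // e.g. (acc, d) -> WeiT
    {
        std::vector<float> workspace(wei_out.size(), 0.0f);      // stage 0: zero the workspace
        for(const auto& slice : partials)                         // stage 1: AtomicAdd GEMM
            for(std::size_t i = 0; i < workspace.size(); ++i)
                workspace[i] += slice[i];
        for(std::size_t i = 0; i < wei_out.size(); ++i)           // stage 2: elementwise kernel
            wei_out[i] = cde_op(workspace[i], d0[i]);
    }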
constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 && + arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_ds, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + ds_g_k_c_xs_lengths, + ds_g_k_c_xs_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + p_ds, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + 
e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + ds_g_k_c_xs_lengths, + ds_g_k_c_xs_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c1f58ccda5138b1405aabeb8c74e22e5daad57d2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -0,0 +1,1795 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
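+
+// Overview (informal summary of the implementation below): this header adds a
+// two-stage grouped convolution backward-weight device operation.
+//   Stage 1: a split-K xdlops GEMM (GridwiseGemm_xdl_cshuffle_v3) accumulates
+//            partial weight gradients into an AccDataType workspace; when
+//            KBatch > 1 the partials are combined with AtomicAdd, otherwise Set.
+//   Stage 2: a batched elementwise kernel casts the AccDataType workspace into
+//            the WeiDataType output while applying the CDE elementwise operation.
+// For NGCHW/NGCDHW layouts an extra elementwise transpose pass first rearranges
+// the A and B tensors into NHWGC order inside the same workspace, which is why
+// GetWorkspaceSizeBytes() also reserves space for A and B in that case.
+//
+// Typical host-side usage (hedged sketch; `device_op` and `workspace` are
+// illustrative names, not part of this header):
+//   auto arg     = device_op.MakeArgumentPointer(/* tensors, lengths, strides, ops, split_k */);
+//   auto invoker = device_op.MakeInvokerPointer();
+//   void* workspace = nullptr;
+//   hip_check_error(hipMalloc(&workspace, device_op.GetWorkSpaceSize(arg.get())));
+//   device_op.SetWorkSpacePointer(arg.get(), workspace);
+//   if(device_op.IsSupportedArgument(arg.get()))
+//       invoker->Run(arg.get(), StreamConfig{});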
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + 
[[maybe_unused]] const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle + : public DeviceGroupedConvBwdWeight +{ + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + + using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using EDataType = WeiDataType; + + // If NGCHW then ADataType must be equal to BDataType + static_assert(!(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) || + is_same_v); + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CDEElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer_v2 = + TransformConvBwdWeightToGemmV2{}; + + static constexpr auto conv_to_gemm_transformer_v1 = + TransformConvBwdWeightToGemm{}; + + static constexpr index_t ClusterLengthMPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + 
strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; + } + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using CElementwiseGridDesc_M_N = + remove_cvref_t())>; + + using GridwiseGemm = + GridwiseGemm_xdl_cshuffle_v3; + + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwiseCast = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + + using GridwiseElementwiseTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + 
OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_e_grid_{p_wei_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + ce_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + cde_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + std::array b_g_n_c_wis_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); + std::array a_g_n_k_wos_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths, + a_g_n_k_wos_strides); + + const auto descs = + conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + ce_elementwise_grid_desc_m_n_ = + conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_)[I2]; + + elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ + ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; + + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template 
MakeNHWGCTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + } + + std::size_t GetWorkspaceSizeBytes() const + { + // Transpose require workspace for A and B + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + else + { + return GetWorkspaceETensorSizeBytes(); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N ce_grid_desc_m_n_; + CElementwiseGridDesc_M_N ce_elementwise_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2TileMapElementwise elementwise_block_2_ctile_map_; + Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_b_; + + NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", " + << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = 
StreamConfig{}) + { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + AccDataType* p_c_grid = type_convert(arg.p_workspace_); + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType); + p_b_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) / + sizeof(BDataType); + } + + // nullptr for output, will be set after workspace set + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge); + + float ave_time = 0; + + index_t k_grain = gemm_arg.KBatch * KPerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * (KPerBlock); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid, + 0, + arg.GetWorkspaceETensorSizeBytes(), + stream_config.stream_id_)); + }; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + clear_workspace(); + }; + + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + 
TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + 
minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + 
InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0.f; + auto launch_elementwise_kernel = [&]() { + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); + const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.ce_elementwise_grid_desc_m_n_) * + arg.Conv_G_; + + std::array in_out_batch_strides = { + static_cast(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; 
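+
+            // Second stage: cast the AccDataType partial results accumulated in the
+            // workspace into the WeiDataType output tensor, applying cde_element_op_.
+            // Source and destination describe the same C/E tensor here, so the same
+            // batch stride (BatchStrideC_) is reused for both input and output.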
+ + const auto kernel = kernel_batched_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + I1, + I1>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(p_c_grid), + make_tuple(arg.p_e_grid_), + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_, + arg.Conv_G_, + in_out_batch_strides, + in_out_batch_strides); + }; + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + const index_t grid_size_a = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t grid_size_b = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType); + BDataType* p_b_out_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) / + sizeof(BDataType); + + // Different data type for A and B is not supported + auto kernel_transpose = kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size_a + grid_size_b), + dim3(BlockSize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + grid_size_a); + } + + avg_time += RunGemmV3(arg, stream_config); + avg_time += launch_elementwise_kernel(); + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // Check this here, it allows to use other instances from factory even + // if workspace is not allocated + if(!arg.p_workspace_) + { + std::cerr << "Warning: Workspace for " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." 
+ << std::endl; + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_NGCHW_GKYXC_NGKHW())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_NGCDHW_GKZYXC_NGKDHW())) + { + return false; + } + } + else + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + if constexpr(NumGroupsToMerge > 1) + { + // support only if whole M and N can be proccessed on one block + if(!(GemmM <= MPerBlock && GemmN <= NPerBlock)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.Conv_K_ == 1)) + { + return false; + } + if(arg.Conv_G_ % NumGroupsToMerge != 0) + { + return false; + } + } + + if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0)) + { + if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0) + { + return false; + } + + if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0) + { + return false; + } + + const index_t input_spatial_acum = ck::accumulate_n( + arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0) + { + return false; + } + + if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0) + { + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + 
a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << NumGroupsToMerge; + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) { + str << ", TransposeTransferSrcScalarPerVector: " + << TransposeTransferSrcScalarPerVector <<", " + << "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector; + } + + + str << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* 
p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0831b754c8a4fd0428961968ebec14e7d7b7cb84 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,871 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template ::type = false> +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle + : public DeviceGroupedConvBwdWeight +{ + using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto GemmK1Number = Number{}; + static constexpr index_t KPerBlock = K0PerBlock * GemmK1Number; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + 
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock; + const index_t GemmKPad = GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + 
out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( 
+ in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Pad + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto GetABCGridDesc() + { + const index_t dim = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using CShuffleDataType = AccDataType; + + using GridwiseGemm = GridwiseGemmMultipleD_Wmma< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + // InMemory Data Descriptor + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + Tuple<>, + CGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + KPerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + true, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + true, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + 
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{})); + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{})); + + using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap( + CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + a_g_n_c_wis_strides, + b_g_k_c_xs_strides, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap( + c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* 
p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void Print(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(arg); + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid " + "setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + CDataType, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch<>, + has_main_loop>; + + using EmptyTuple = Tuple<>; + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + EmptyTuple{}, // Ds + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock{}, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(has_main_k0_block_loop) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // TODO: Add support for split_k > 1 + if(arg.k_batch_ != 1) + { + return false; + } + + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())) + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const 
OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << CShuffleBlockTransferScalarPerVector_NPerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 71c528e4c92e339b334d78faab82fe10c7bf914e..996107343d510f6b0c5e418881fa90ea5b60e4b2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,8 +12,10 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -21,34 +23,9 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -struct ComputePtrOffsetOfStridedBatch -{ - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; -}; - -} // namespace - template (compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); - __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; GridwiseGemm::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, @@ -163,7 +137,9 @@ template + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeight_Xdl_CShuffle : public DeviceGroupedConvBwdWeight + OutElementwiseOperation, + ComputeTypeA, + ComputeTypeB> { using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; @@ -189,30 +167,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle // TODO make A/B datatype different using ABDataType = InDataType; - // 1d - static constexpr bool is_GNWK_GKXC_GNWC = - is_same_v && - is_same_v && - is_same_v; - // 2d - static constexpr bool is_NHWGK_GKYXC_NHWGC = - is_same_v && - is_same_v && - is_same_v; - static constexpr bool is_GNHWK_GKYXC_GNHWC = - is_same_v && - is_same_v && - is_same_v; - // 3d - static constexpr bool is_NDHWGK_GKZYXC_NDHWGC = - is_same_v && - is_same_v && - is_same_v; - static constexpr bool is_GNDHWK_GKZYXC_GNDHWC = - is_same_v && - is_same_v && - is_same_v; - static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = 
Number<2>{}; @@ -220,8 +174,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; - static constexpr auto K1Number = Number{}; - static constexpr auto GemmK1Number = K1Number; + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvBwdWeightToGemm{}; // Bytes per 32 lds bank: 32 * 4 bytes static constexpr auto BankLength = 128; @@ -237,690 +198,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; static constexpr auto BBlockLdsN1Padding = 4; - template ::type = false> - constexpr static auto - make_out_grid_desc(const ck::index_t N, - const ck::index_t Ho, - const ck::index_t Wo, - const ck::index_t K, - const std::array& output_strides) - { - const index_t WoStride = output_strides[4]; - const auto KStride = Number<1>{}; - return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), - make_tuple(WoStride, KStride)); - } - - template ::type = false> - constexpr static auto - make_in_grid_desc(const ck::index_t N, - const ck::index_t Hi, - const ck::index_t Wi, - const ck::index_t C, - const std::array& input_strides) - { - const index_t NStride = input_strides[1]; - const index_t HiStride = input_strides[3]; - const index_t WiStride = input_strides[4]; - const auto CStride = input_strides[2]; - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C), - make_tuple(WiStride, CStride)); - } - else - { - return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), - make_tuple(NStride, HiStride, WiStride, CStride)); - } - } - - template ::type = false> - constexpr static auto - make_wei_grid_desc(const ck::index_t K, - const ck::index_t Y, - const ck::index_t X, - const ck::index_t C, - const std::array& weights_strides) - { - const auto CStride = Number<1>{}; - const auto KStride = weights_strides[1]; - return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride)); - } - - template ::type = false> - constexpr static auto - make_out_grid_desc(const ck::index_t N, - const ck::index_t Do, - const ck::index_t Ho, - const ck::index_t Wo, - const ck::index_t K, - const std::array& output_strides) - { - const index_t WoStride = output_strides[5]; - const auto KStride = Number<1>{}; - return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), - make_tuple(WoStride, KStride)); - } - - template ::type = false> - constexpr static auto - make_in_grid_desc(const ck::index_t N, - const ck::index_t Di, - const ck::index_t Hi, - const ck::index_t Wi, - const ck::index_t C, - const std::array& input_strides) - { - const index_t NStride = input_strides[1]; - const index_t DiStride = input_strides[3]; - const index_t HiStride = input_strides[4]; - const index_t WiStride = input_strides[5]; - const auto CStride = input_strides[2]; - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), - make_tuple(WiStride, CStride)); - } - else - { - return make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - } - } - - template ::type = false> - constexpr static auto - make_wei_grid_desc(const ck::index_t K, - const ck::index_t Z, - const 
ck::index_t Y, - const ck::index_t X, - const ck::index_t C, - const std::array& weights_strides) - { - const auto CStride = Number<1>{}; - const auto KStride = weights_strides[1]; - return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), - make_tuple(KStride, CStride)); - } - - template ::type = false> - static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& /* input_strides */, - const std::array& /* weights_strides */, - const std::array& /* output_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const ck::index_t batch_k) - { - using namespace ck; - - const index_t Wi = input_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[0]; - const index_t ConvStrideW = conv_filter_strides[0]; - const index_t ConvDilationW = conv_filter_dilations[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - - const index_t GemmKTotal = N * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X; - - const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; - const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; - - const index_t GemmKBatch = batch_k; - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * - K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); - - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); - - return 
make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); - } - else - { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); - const auto in_n_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); - - // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( - in_n_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmktotal_gemmn_grid_desc = - transform_tensor_descriptor(in_n_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(X, C)), - make_merge_transform(make_tuple(N, Wo))), - make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); - - // Padd - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmM, PadGemmM), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - 
make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmN, PadGemmN), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto wei_gemmm_gemmn_pad_grid_desc = - transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmM, PadGemmM), - make_right_pad_transform(GemmN, PadGemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, - wei_gemmm_gemmn_pad_grid_desc); - } - } - - template ::type = false> - static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& input_strides, - const std::array& weights_strides, - const std::array& output_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const ck::index_t batch_k) - { - using namespace ck; - - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = output_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[1]; - - const index_t Y = filter_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[1]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const index_t GemmKTotal = N * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X * Y; - - const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; - const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; - - const index_t GemmKBatch = batch_k; - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * - K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - - const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); - const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); - const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); - - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); 
- - // B: input tensor - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_grid_desc); - } - else - { - // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmktotal_gemmn_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // Padd - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - 
make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmM, PadGemmM), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmN, PadGemmN), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto wei_gemmm_gemmn_pad_grid_desc = - transform_tensor_descriptor(wei_grid_desc, - make_tuple(make_right_pad_transform(GemmM, PadGemmM), - make_right_pad_transform(GemmN, PadGemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, - wei_gemmm_gemmn_pad_grid_desc); - } - } - - template ::type = false> - static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& input_strides, - const std::array& weights_strides, - const std::array& output_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const ck::index_t batch_k) - { - using namespace ck; - - const index_t Di = input_spatial_lengths[0]; - const index_t Hi = input_spatial_lengths[1]; - const index_t Wi = input_spatial_lengths[2]; - - const index_t Do = output_spatial_lengths[0]; - const index_t Ho = output_spatial_lengths[1]; - const index_t Wo = output_spatial_lengths[2]; - - const index_t Z = filter_spatial_lengths[0]; - const index_t Y = filter_spatial_lengths[1]; - const index_t X = filter_spatial_lengths[2]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - const index_t GemmKTotal = N * Do * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * Z * X * Y; - - const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; - const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; - - const index_t GemmKBatch = batch_k; - const index_t GemmK0 = - math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * - K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - - const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); - const auto in_grid_desc = 
make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); - const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); - - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_grid_desc); - } - else - { - // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_dip_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto 
in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), - make_merge_transform(make_tuple(N, Do, Ho, Wo))), - make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // Padd - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmM, PadGemmM), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = - transform_tensor_descriptor( - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmKBatch), - make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmN, PadGemmN), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto wei_gemmm_gemmn_pad_grid_desc = - transform_tensor_descriptor(wei_grid_desc, - make_tuple(make_right_pad_transform(GemmM, PadGemmM), - make_right_pad_transform(GemmN, PadGemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, - wei_gemmm_gemmn_pad_grid_desc); - } - } // function end - template ::type = false> static auto GetABCGridDesc() { @@ -929,20 +206,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array lengths{1}; const std::array strides{1, 1, 1, 1}; const std::array params{1}; - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(dim, - dim, - dim, - lengths, - lengths, - lengths, - strides, - strides, - strides, - params, - params, - params, - params, - batch); + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } template ::type = false> @@ -953,20 +231,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array lengths{1, 1}; const std::array strides{1, 1, 1, 1, 1}; const std::array params{1, 1}; - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, - dim, - dim, - lengths, - lengths, - lengths, - strides, - strides, - strides, - params, - params, - params, - params, - batch); + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + dim, + dim, + dim, + 
lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } template ::type = false> @@ -977,66 +256,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array lengths{1, 1, 1}; const std::array strides{1, 1, 1, 1, 1, 1}; const std::array params{1, 1, 1}; - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, - dim, - dim, - lengths, - lengths, - lengths, - strides, - strides, - strides, - params, - params, - params, - params, - batch); - } - - // type convert descs - template - static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) - { - const auto m0 = desc_m0.GetLength(I0); - const index_t loop_step = gridSize * blockSize * 4; - const auto pad = math::integer_least_multiple(m0, loop_step) - m0; - const auto desc_m0_pad = - transform_tensor_descriptor(desc_m0, - make_tuple(make_right_pad_transform(m0, pad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - return desc_m0_pad; - } - - template - static auto MakeDescriptor_M0(const std::array& shape, - const std::array& stride, - index_t gridSize, - index_t blockSize) - { - auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); - auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); - - // nd desc - [s0, s1, s2, ...] - const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); - - // merge nd to 1d desc - [s0 * s1 * ...] - if constexpr(Dim > 1) - { - const auto desc_m0 = transform_tensor_descriptor( - desc, - make_tuple(make_merge_transform(tupleOfShape)), - make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), - make_tuple(Sequence<0>{})); - - return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); - } - else - return PadDescriptor_M0_1d(desc, gridSize, blockSize); + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } - using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1)); - using ABCGridDescs = decltype(GetABCGridDesc()); using AGridDesc_K0_M_K1 = remove_cvref_t; @@ -1045,7 +281,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, AccDataType, CDataType, InMemoryDataOperationEnum::AtomicAdd, @@ -1090,7 +327,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, - true>; + true, + 1, + PipelineVersion::v1, + ComputeTypeA, + ComputeTypeB>; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -1104,12 +345,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, const OutDataType* p_out_grid, - const std::array& a_g_n_c_wis_lengths, // input - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, // weight - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, // output - const std::array& e_g_n_k_wos_strides, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& 
a_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -1134,10 +375,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle a_element_op_{out_element_op}, b_element_op_{in_element_op}, c_element_op_{wei_element_op}, - Conv_G_{a_g_n_c_wis_lengths[0]}, - Conv_N_{a_g_n_c_wis_lengths[1]}, - Conv_K_{b_g_k_c_xs_lengths[1]}, - Conv_C_{a_g_n_c_wis_lengths[2]}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, input_spatial_lengths_{}, filter_spatial_lengths_{}, output_spatial_lengths_{}, @@ -1147,32 +388,33 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle k_batch_{split_k} { constexpr index_t spatial_offset = 3; - std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, - end(a_g_n_c_wis_lengths), + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), begin(input_spatial_lengths_)); - std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, - end(b_g_k_c_xs_lengths), + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), begin(filter_spatial_lengths_)); - std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, - end(e_g_n_k_wos_lengths), + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); const auto descs = - DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - Conv_N_, - Conv_K_, - Conv_C_, - input_spatial_lengths_, - filter_spatial_lengths_, - output_spatial_lengths_, - a_g_n_c_wis_strides, - b_g_k_c_xs_strides, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - k_batch_); + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; @@ -1182,8 +424,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = Conv_K_ * Conv_C_ * std::accumulate(begin(filter_spatial_lengths_), @@ -1212,13 +454,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle Block2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; index_t M01_; index_t N01_; - InElementwiseOperation a_element_op_; - OutElementwiseOperation b_element_op_; + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; WeiElementwiseOperation c_element_op_; // for checking IsSupportedArgument() @@ -1281,7 +523,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype + ADataType, + BDataType, CDataType, OutElementwiseOperation, 
InElementwiseOperation, @@ -1290,7 +533,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle remove_reference_t, remove_reference_t, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch<>, has_main_loop>; return launch_and_time_kernel(stream_config, @@ -1337,23 +580,29 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } if constexpr(NDimSpatial == 1) { - if constexpr(!is_GNWK_GKXC_GNWC) + if constexpr(!is_GNWC_GKXC_GNWK()) { return false; } } else if constexpr(NDimSpatial == 2) { - if constexpr(!(is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK())) { return false; } } else if constexpr(NDimSpatial == 3) { - if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)) + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())) { return false; } @@ -1407,12 +656,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle MakeArgument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, const OutDataType* p_out_grid, - const std::array& a_g_n_c_wis_lengths, // input - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, // weight - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, // output - const std::array& e_g_n_k_wos_strides, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -1425,12 +674,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle return Argument{p_in_grid, p_wei_grid, p_out_grid, - a_g_n_c_wis_lengths, // input - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, // weight - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, // output - e_g_n_k_wos_strides, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1449,12 +698,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle MakeArgumentPointer(const void* p_in_grid, void* p_wei_grid, const void* p_out_grid, - const std::array& a_g_n_c_wis_lengths, // input - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, // weight - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, // output - const std::array& e_g_n_k_wos_strides, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -1467,12 +716,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), static_cast(p_out_grid), - a_g_n_c_wis_lengths, // input - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, // weight - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, // output - e_g_n_k_wos_strides, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // 
weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 8b22bd209043e900cd594bcac7923833e9a26536..3e14f66a09c992259e723c5ebccf9bc672e83f9f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,10 +15,11 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -29,51 +30,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. 
* @@ -134,19 +90,19 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t b_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t c_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -260,18 +216,18 @@ template struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK - : public DeviceGroupedConvFwdMultipleD + : public DeviceGroupedConvFwdMultipleABD { using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK; @@ -282,36 +238,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; template static auto - MakeAGridDescriptor_AK0_M_AK1(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -329,12 +266,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK template static auto - MakeBGridDescriptor_BK0_N_BK1(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template 
MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -352,13 +287,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK } template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -366,27 +298,27 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK return out_gemmm_gemmn_desc; } - static auto MakeDsGridDescriptor_M_N( - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides) + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], - ds_g_n_k_wos_strides[i]); + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); }, Number{}); } // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using AGridDesc_AK0_M_AK1 = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_BK0_N_BK1 = - remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // GridwiseGemm using GridwiseGemm = @@ -469,21 +401,22 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, a_grid_desc_ak0_m_ak1_{ - DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1( - b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_k0_m0_m1_k1_{}, b_grid_desc_k0_n0_n1_k1_{}, ds_grid_desc_m0_m10_m11_n0_n10_n11_{}, @@ -514,6 +447,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK using DLayout = remove_cvref_t>; using DDataType = remove_cvref_t>; + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + 
b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + // D pointer p_ds_grid_(i) = static_cast(p_ds[i]); @@ -521,8 +465,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; // D desc - ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); }); // populate desc for Ds/E @@ -566,6 +510,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK // tensor descriptors for problem definiton index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; DsGridDesc_M_N ds_grid_desc_m_n_; @@ -581,7 +528,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK DefaultBlock2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOperation a_element_op_; @@ -645,7 +592,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11, DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11, DefaultBlock2CTileMap, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, has_double_loop>; @@ -710,11 +657,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK namespace ctc = tensor_layout::convolution; // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908" || - ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" || - ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" || - ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")) + if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || + ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())) { return false; } @@ -892,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK cde_element_op}; } + static auto + MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array 
input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + static auto MakeInvoker() { return Invoker{}; } std::unique_ptr MakeArgumentPointer( @@ -936,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK cde_element_op); } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return 
std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + std::unique_ptr MakeInvokerPointer() override { return std::make_unique(Invoker{}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index f18fbcfe4b51031009f6329baac2bc1ac209e1fa..50e171e503cd10643623e133391dbbbc42f01d9e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -106,8 +106,8 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ - defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -234,36 +234,17 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; template static auto - MakeAGridDescriptor_AK0_M_AK1(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - c_g_n_k_wos_lengths, - c_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -282,12 +263,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd static auto - MakeBGridDescriptor_BK0_N_BK1(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template 
MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -305,13 +284,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd - static auto - MakeCGridDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) + static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(c_g_n_k_wos_lengths, - c_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -320,11 +296,13 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_BK0_N_BK1 = - remove_cvref_t({}, {}))>; - using CGridDesc_M_N = remove_cvref_t({}, {}))>; + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using CGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // GridwiseGemm using GridwiseGemm = @@ -395,21 +373,22 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd(p_b)}, p_c_grid_{static_cast(p_c)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, a_grid_desc_ak0_m_ak1_{ - DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - c_g_n_k_wos_lengths, - c_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1( - b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, - c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(c_g_n_k_wos_lengths, - c_g_n_k_wos_strides)}, + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + c_grid_desc_m_n_{ + DeviceOp::MakeCGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_k0_m0_m1_k1_{}, b_grid_desc_k0_n0_n1_k1_{}, c_grid_desc_m0_m10_m11_n0_n10_n11_{}, @@ -472,6 +451,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" 
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). 
+ * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfG compute_ptr_offset_of_groups, + const ComputePtrOffsetOfN compute_ptr_offset_of_n) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); + + if constexpr(is_same_v) + { + a_element_op.InitUnaryOpPtrOnDevice(); + } + if constexpr(is_same_v) + { + b_element_op.InitUnaryOpPtrOnDevice(); + } + if constexpr(is_same_v) + { + cde_element_op.InitUnaryOpPtrOnDevice(); + } + + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. 
+ if constexpr(isMultiA) + { + const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}([&](auto i) { + p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; + }); + } + else + { + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + static_for<0, 1, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; }); + } + + const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + + GridwiseGemm::template Run( + p_as_grid + a_group_offset + a_n_offset, + p_bs_grid + b_group_offset, + p_ds_grid_grp, + p_e_grid + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; + ignore = block_2_ctile_map; +#endif +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. 
+// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler(), + index_t NumGroupsToMerge = 1> +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + + static_assert(NumGroupsToMerge >= 1); + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + // NGCHW is not supported for multiAB + static_assert(!(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) || + !(isMultiA || isMultiB)); + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_GKYXC_NGKHW(), + ctc::NHWGC, + std::conditional_t(), + ctc::NDHWGC, + ALay>>; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_GKYXC_NGKHW(), + ctc::NHWGK, + std::conditional_t(), + ctc::NDHWGK, + ELay>>; + + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. 
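The descriptor makers above hand these convolutions to TransformConvFwdToGemm, which exposes them as an implicit GEMM that matrix_padder then pads to the block sizes. A rough sketch of the size bookkeeping for the 2D case, assuming the usual mapping (GemmM over N and the output spatial dims, GemmN over the output channels K, GemmK over C and the filter window); the helper and its names are illustrative only and are not part of the transformer:

    #include <cstdint>

    // Illustrative only: GEMM problem sizes for one group of the 2D forward conv
    //   out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
    // under the assumed mapping GemmM = N*Ho*Wo, GemmN = K, GemmK = C*Y*X.
    struct ConvAsGemmSizes
    {
        std::int64_t GemmM; // rows of A (im2col view of the input) and of E (output)
        std::int64_t GemmN; // output channels K
        std::int64_t GemmK; // reduction over C and the filter window
    };

    inline ConvAsGemmSizes conv2d_fwd_gemm_sizes(std::int64_t N, std::int64_t K,
                                                 std::int64_t C, std::int64_t Ho,
                                                 std::int64_t Wo, std::int64_t Y,
                                                 std::int64_t X)
    {
        return ConvAsGemmSizes{N * Ho * Wo, K, C * Y * X};
    }

    // e.g. N=2, K=64, C=32, Ho=Wo=28, Y=X=3  ->  GemmM=1568, GemmN=64, GemmK=288

matrix_padder then rounds the raw M/N/K up toward multiples of MPerBlock, NPerBlock, and KPerBlock (subject to GemmSpec) before the block-wise descriptors are derived.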
+ static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + // If we are using multiAB and one of the template datatype parameters is not a tuple, convert + // it to it + using GemmADataType = std::conditional_t, ADataType>; + using GemmBDataType = std::conditional_t, BDataType>; + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType + // Use appropriate gridwise gemm + using GridwiseGemm = + std::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; + + // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. + using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not + // in initializer list what is required for single const pointer). 
+ using AGridPointer = remove_cvref_t< + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + using BGridPointer = remove_cvref_t< + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + + static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock; + + using GridwiseElementwiseInputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + using GridwiseElementwiseOutputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I0, + I1>; + + // Argument + struct Argument : public BaseArgument + { + Argument(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides)}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + num_group_{a_g_n_c_wis_lengths_[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + 
conv_N_per_block_{conv_to_gemm_transformer_.N_}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // A/B/E Batch Stride + if constexpr(isMultiA || isMultiB) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideA_(i) = + a_g_n_c_wis_strides_[0] * NumGroupsToMerge; + + // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data + // type is not tuple) + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiA) + { + // p_as is tuple + p_as_grid_(i) = static_cast(p_as[i.value]); + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. + compute_ptr_offset_of_n_.BatchStrideA_(i) = + a_g_n_c_wis_strides_[1] * conv_N_per_block_; + } + else + { + // if MultiB and not MultiA then p_as is single pointer + p_as_grid_(i) = static_cast(p_as); + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_c_wis_strides_[1] * conv_N_per_block_; + } + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideB_(i) = + b_g_k_c_xs_strides_[0] * NumGroupsToMerge; + + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. 
+ if constexpr(isMultiB) + { + // p_bs is tuple + p_bs_grid_(i) = static_cast(p_bs[i.value]); + } + else + { + // if MultiA and not MultiB then p_bs is single pointer + p_bs_grid_(i) = static_cast(p_bs); + } + }); + } + else + { + compute_ptr_offset_of_groups_.BatchStrideA_ = + a_g_n_c_wis_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_groups_.BatchStrideB_ = + b_g_k_c_xs_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_c_wis_strides_[1] * conv_N_per_block_; + + // p_as and p_bs are pointers + p_as_grid_(I0) = static_cast(p_as); + p_bs_grid_(I0) = static_cast(p_bs); + } + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + ds_g_n_k_wos_strides_[i], + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}; + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); + compute_ptr_offset_of_groups_.BatchStrideE_ = + e_g_n_k_wos_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; + + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + // Use not modified base strides + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + + elementwise_block_2_ctile_map_transpose_a_ = 
Block2TileMapElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceSizeBytes() const + { + // Transpose require workspace for A and B + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes(); + } + else + { + return 0; + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers (tuple if multi AB, pointer if no) + AGridPointer p_as_grid_; + BGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + index_t conv_N_per_block_; + + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_e_; + + NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch + compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + const index_t gdy = arg.num_group_ 
/ NumGroupsToMerge; + const index_t gdz = num_workgroups_per_Conv_N; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + if constexpr(isMultiA || isMultiB) + { + // Generate tuples with grid descriptors for each A and B + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto) { return arg.a_grid_desc_ak0_m_ak1_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto) { return arg.b_grid_desc_bk0_n_bk1_; }, Number{}); + + const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + AGridPointer, + BGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + decltype(as_grid_desc_ak0_m_ak1), + decltype(bs_grid_desc_bk0_n_bk1), + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + isMultiA, + isMultiB>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + } + else + { + const ADataType* p_a_grid = arg.p_as_grid_.At(I0); + EDataType* p_e_grid = arg.p_e_grid_; + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + } + + const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + const ADataType*, + const BDataType*, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + isMultiA, + isMultiB>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_a_grid, // Pass just A descriptor instead of tuple + arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple + arg.p_ds_grid_, + p_e_grid, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + } + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0.f; + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + 
is_NGCDHW_GKZYXC_NGKDHW()) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.p_as_grid_.At(I0)), + make_tuple(p_a_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + element_wise::PassThrough{}); + } + + avg_time += RunGemm(arg, stream_config); + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const EDataType* p_e_out_grid = + type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + + EDataType* p_e_in_grid = arg.p_e_grid_; + + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_out_grid), + make_tuple(p_e_in_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } + + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + const index_t G = arg.b_g_k_c_xs_lengths_[I0]; + const index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + + // check device + if(get_device_name() == "gfx908") + { + // FIXME: re-enable fp64 when SWDEV-335738 is fixed + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + if(!ck::is_xdl_supported()) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + 
if(filter_spatial_dim != I3) + { + return false; + } + } + if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK()) + { + return false; + } + } + + if constexpr(NumGroupsToMerge > 1) + { + if(!(C == 1)) + { + return false; + } + if(G % NumGroupsToMerge != 0) + { + return false; + } + if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK() || + is_NGCSpatial_GKSpatial_NGKSpatial())) + { + return false; + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + // Check access per C + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + // If not possible, check access per G + if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) && + (is_NSpatialGC_GKSpatial_NSpatialGK() || + is_NGCSpatial_GKSpatial_NGKSpatial()) && + G % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + + if constexpr(is_same_v) + { + // G and K must be the same + if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] || + arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2]) + { + valid = false; + } + } + else + { + // E and D must have the same shape + for(index_t d = 0; d < NDimSpatial + 3; d++) + { + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + valid = false; + } + } + } + } + else + { + valid = false; + } + }); + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + if constexpr(isMultiA || isMultiB) + { + // Genarate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = + 
generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + 
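        // array_convert performs an element-wise narrowing copy (long_index_t -> index_t) so the
        // index_t-based Argument constructor can be reused for the 64-bit host-facing overload.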
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], 
ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << NumGroupsToMerge + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..589a0daa99d4070ebf72d64961e239cfdc7c5488 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -0,0 +1,1562 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
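//
// A minimal host-side usage sketch for the device operation defined in this header. The
// concrete instantiation (DeviceOpInstance), tensor pointers, and shape arrays below are
// illustration-only placeholders and are not part of this patch; DeviceMem and StreamConfig
// are assumed to be the usual CK host utilities. The member functions used
// (MakeArgumentPointer, IsSupportedArgument, GetWorkSpaceSize, SetWorkSpacePointer,
// MakeInvokerPointer, Invoker::Run) are the ones declared further down in this file.
//
// @code
// using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//
// auto op  = DeviceOpInstance{}; // hypothetical DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<...>
// auto arg = op.MakeArgumentPointer(p_in, p_wei, {}, p_out,
//                                   a_lengths, a_strides,   // input  [G, N, C, spatial...]
//                                   b_lengths, b_strides,   // weight [G, K, C, spatial...]
//                                   {}, {},                 // no D tensors
//                                   e_lengths, e_strides,   // output [G, N, K, spatial...]
//                                   conv_strides, conv_dilations, left_pads, right_pads,
//                                   PassThrough{}, PassThrough{}, PassThrough{});
//
// if(op.IsSupportedArgument(arg.get()))
// {
//     // Workspace is non-zero only for layouts that need an internal transpose (NGCHW/NGCDHW).
//     DeviceMem workspace(op.GetWorkSpaceSize(arg.get()));
//     op.SetWorkSpacePointer(arg.get(), workspace.GetDeviceBuffer());
//
//     auto invoker  = op.MakeInvokerPointer();
//     float time_ms = invoker->Run(arg.get(), StreamConfig{nullptr, true});
// }
// @endcode
//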
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). 
+ * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_fwd_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, + [[maybe_unused]] const index_t groups_count) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); + const index_t& num_blocks_per_n = groups_count; + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, + [[maybe_unused]] const index_t groups_count) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); + const index_t& num_blocks_per_n = groups_count; + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const 
long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType> +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiD = DsDataType::Size() > 0; + static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; + + // multi ABD not supported + static_assert(!isMultiABD, "Multi A, Mutli B and Multi D are not supported"); + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + template + static auto + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_GKYXC_NGKHW(), + ctc::NHWGC, + std::conditional_t(), + ctc::NDHWGC, + ALay>>; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = 
in_gemmm_gemmk_desc.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_GKYXC_NGKHW(), + ctc::NHWGK, + std::conditional_t(), + ctc::NDHWGK, + ELay>>; + + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + +#define GridwiseGemmV3TemplateParams \ + tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ + tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, \ + MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ + AComputeDataType, BComputeDataType + + // Use appropriate gridwise gemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3; + + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + + static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock; + + using GridwiseElementwiseInputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / 
ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + using GridwiseElementwiseOutputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I0, + I1>; + + static auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const index_t M = e_grid_desc_m_n.GetLength(I0); + const index_t N = e_grid_desc_m_n.GetLength(I1); + return GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, GridwiseGemm::CalculateMBlock(M), GridwiseGemm::CalculateNBlock(N)); + } + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_as, + const void* p_bs, + const std::array&, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>&, + const std::array, NumDTensor>&, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{}, + p_b_grid_{}, + p_e_grid_{static_cast(p_e)}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides)}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + num_group_{a_g_n_c_wis_lengths_[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + conv_N_per_block_{conv_to_gemm_transformer_.N_}, + a_grid_desc_ak0_m_ak1_{ + MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, + b_grid_desc_bk0_n_bk1_{ + MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // A/B/E Batch/N Stride + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0]; + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_; 
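            // compute_ptr_offset_of_groups_ steps the base pointers along the G dimension
            // (stride [0]), while compute_ptr_offset_of_n_ steps them by conv_N_per_block_
            // images along the N dimension (stride [1]); the kernel recovers the group and
            // N-block indices from blockIdx.y and applies both offsets before the GEMM.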
+ + // p_as and p_bs are pointers + p_a_grid_ = static_cast(p_as); + p_b_grid_ = static_cast(p_bs); + + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0]; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + // Use not modified base strides + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceSizeBytes() const + { + // Transpose require workspace for A and B + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes(); + } + else + { + return 0; + } + } + + void Print() const + { + std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers (tuple if multi AB, pointer if no) + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + index_t conv_N_per_block_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_M_N e_grid_desc_m_n_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // block-to-e-tile map + Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + 
elementwise_block_2_ctile_map_transpose_e_; + + NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); + + gdy *= arg.num_group_ * num_workgroups_per_Conv_N; + + index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const ADataType* p_a_grid = arg.p_a_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + } + + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, arg.p_b_grid_, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1}; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, + arg.num_group_); + } + else + { + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, + arg.num_group_); + } + }; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + 
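                    // Select a kernel specialized for the exact tail length reported by
                    // CalculateKBlockLoopTailNum; the Two..Seven cases are only instantiated
                    // when the pipeline actually prefetches that many stages (see the
                    // PrefetchStages guards below).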
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + 
GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0.f; + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(p_a_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + element_wise::PassThrough{}); + } + + avg_time += RunGemm(arg, stream_config); + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const 
EDataType* p_e_out_grid = + type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + + EDataType* p_e_in_grid = arg.p_e_grid_; + + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_out_grid), + make_tuple(p_e_in_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } + + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + const index_t G = arg.b_g_k_c_xs_lengths_[I0]; + const index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + + // check device + if(get_device_name() == "gfx908") + { + // FIXME: re-enable fp64 when SWDEV-335738 is fixed + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + + if(!ck::is_xdl_supported()) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, 
std::multiplies<>()); + + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // Gridwise gemm v3 doesn't verify descriptors size + if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB()) + { + return false; + } + + // check Gridwise GEMM + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, I1 /*KBatch*/}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_as, + const void* p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(const void* p_as, + const void* p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array 
conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> 
ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument 
pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index caa18b709cea604930c3c81c80cdbebc5d9cfaa1..58f1396710287b532d89a003dc226cabe3a374e1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -156,16 +156,16 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t b_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t e_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -309,36 +309,16 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -347,13 +327,10 @@ struct 
DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } template - static auto - MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); @@ -362,13 +339,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle } template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -446,11 +420,14 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); } - using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; - using RGridDesc_M = remove_cvref_t({}, {}))>; + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using RGridDesc_M = remove_cvref_t({}, {}))>; // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< @@ -550,21 +527,23 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, p_rs_grid_{}, // FIXME - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, r_grid_desc_m_{ DeviceOp::MakeRGridDescriptor_M(r_g_n_wos_lengths, r_g_n_wos_strides)}, a_grid_desc_ak0_m_ak1_{ @@ -620,9 +599,20 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // D batch stride compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + 
b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + // D desc - ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -659,6 +649,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle EDataType* p_e_grid_; typename GridwiseGemm::RsGridPointer p_rs_grid_; + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + // tensor descriptors for problem definiton AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_N_K b_grid_desc_n_k_; @@ -813,8 +805,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return false; } } - else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || - get_device_name() == "gfx941" || get_device_name() == "gfx942") + else if(ck::is_lds_direct_load_supported()) { if constexpr(!(is_same_v || is_same_v || is_same_v || is_same_v)) @@ -834,7 +825,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // check if it's 1x1, stride=1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -851,7 +842,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // check if it's 1x1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -1090,7 +1081,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle" + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 7c726cd85128ff61ac4eb675d5c3dee9e9de3315..1c9f6b00948b25b5f2b1a7e045917c65ce4b1402 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
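Note on the recurring `[i + 2]` -> `[i + 3]` change in the Filter1x1 specialization checks of this patch: the weight-length array appears to be laid out as [G, K, C, X0, X1, ...], so the spatial filter extents start at index 3. The sketch below is illustrative only and not part of the patch; the helper name and the use of std::int64_t (standing in for long_index_t) are assumptions made to keep it self-contained.

#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical standalone helper mirroring the Filter1x1Stride1Pad0 check.
template <std::size_t NDimSpatial>
bool is_filter_1x1_stride1_pad0(
    const std::array<std::int64_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // [G, K, C, X0, X1, ...]
    const std::array<std::int64_t, NDimSpatial>& conv_filter_strides,
    const std::array<std::int64_t, NDimSpatial>& input_left_pads,
    const std::array<std::int64_t, NDimSpatial>& input_right_pads)
{
    for(std::size_t i = 0; i < NDimSpatial; ++i)
    {
        const std::int64_t X = b_g_k_c_xs_lengths[i + 3]; // skip G, K and C
        if(!(X == 1 && conv_filter_strides[i] == 1 && input_left_pads[i] == 0 &&
             input_right_pads[i] == 0))
        {
            return false;
        }
    }
    return true;
}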
#pragma once @@ -15,10 +15,11 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -27,55 +28,6 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - -} // namespace - // // @brief Device Convolution operation. // @@ -100,22 +52,23 @@ template struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle - : public DeviceGroupedConvFwdMultipleD + : public DeviceGroupedConvFwdMultipleABD { using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle; static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr index_t KPerBlock = K0PerBlock * K1; + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + + static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; + static constexpr auto BEnableLds_auto = MWaves == 1 ? 
false : true; - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = true; + static constexpr auto BEnableLds_manu = true; + + static constexpr auto AEnableLds = + AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1); + static constexpr auto BEnableLds = + BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1); + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); + conv_to_gemm_transformer.template MakeADescriptor_M_K(); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); - return in_gemmm_gemmk_desc; + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } template - static auto - MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); - return wei_gemmn_gemmk_desc; + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, 
K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -230,68 +240,30 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle return out_gemmm_gemmn_desc; } - static auto MakeDsGridDescriptor_M_N( - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides) + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], - ds_g_n_k_wos_strides[i]); + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); }, Number{}); } // desc for problem definition - using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; - - // A desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) - { - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK1 = K1; - const auto AK0 = K / AK1; - - return transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - // B desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) - { - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK1 = K1; - const auto BK0 = K / BK1; - - return transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{})); - using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{})); + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; 
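The dummy_conv_to_gemm_transformer introduced above exists only so the grid-descriptor types can be deduced with decltype in an unevaluated context; no descriptor is actually built at this point. Presumably this replaces the old pattern of passing empty braced arrays because the transformer now carries the convolution problem state itself. A minimal sketch of the pattern follows, not part of the patch, with hypothetical names (Transformer, MakeDescriptor, Descriptor) standing in for the real types.

#include <type_traits>

struct Transformer
{
    int lengths[3] = {}; // placeholder for the real conv lengths/strides/pads
};

struct Descriptor
{
    int m = 0;
    int k = 0;
};

// Maker used both for real arguments at runtime and for type deduction below.
inline Descriptor MakeDescriptor(const Transformer& t)
{
    return Descriptor{t.lengths[0], t.lengths[1]};
}

// A default-constructed constexpr object is enough: decltype never evaluates the call.
static constexpr Transformer dummy_transformer{};
using DeducedDescriptor =
    std::remove_cv_t<std::remove_reference_t<decltype(MakeDescriptor(dummy_transformer))>>;

static_assert(std::is_same_v<DeducedDescriptor, Descriptor>,
              "descriptor type is deduced from the dummy object without evaluating the maker");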
+ using AGridDesc = + decltype(DeviceOp::MakeAGridDescriptor(dummy_conv_to_gemm_transformer)); + using BGridDesc = + decltype(DeviceOp::MakeBGridDescriptor(dummy_conv_to_gemm_transformer)); + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; // GridwiseOp - using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseOp = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -300,8 +272,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle DsDataType, EDataType, // InMemory Data Descriptor - AGridDesc_AK0_M_AK1, - BGridDesc_BK0_N_BK1, + AGridDesc, + BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, // ElementwiseOp Family @@ -312,9 +284,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, - MPerWMMA, - NPerWMMA, + KPerBlock, + MPerWmma, + NPerWmma, K1, MRepeat, NRepeat, @@ -327,6 +299,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, + AEnableLds, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, @@ -335,6 +308,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, + BEnableLds, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, @@ -375,23 +349,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, - a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_{DeviceOp::MakeAGridDescriptor(conv_to_gemm_transformer_)}, + b_grid_desc_{DeviceOp::MakeBGridDescriptor(conv_to_gemm_transformer_)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, @@ -430,8 +402,24 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle }); // D desc - ds_grid_desc_m_n_ = - DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides); + ds_grid_desc_m_n_ = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + return 
DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }, + Number{}); // populate desc for Ds/E e_grid_desc_mblock_mperblock_nblock_nperblock_ = @@ -443,8 +431,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle void Print() const { - std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; - std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + std::cout << "A[M, K]: " << a_grid_desc_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_ << std::endl; static_for<0, NumDTensor, 1>{}( [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; @@ -459,14 +447,15 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // tensor descriptors for problem definiton index_t num_group_; - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; // tensor descriptors for block/thread-wise copy - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -476,7 +465,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOperation a_element_op_; @@ -513,13 +502,22 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; - const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< GridwiseOp, ADataType, BDataType, @@ -528,12 +526,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc, + DeviceOp::BGridDesc, typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel(stream_config, @@ -549,8 +547,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle arg.b_element_op_, arg.cde_element_op_, arg.a_g_n_c_wis_lengths_[0], // Group count - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_etile_map_, @@ -579,8 +577,7 @@ struct 
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle namespace ctc = tensor_layout::convolution; // check device - if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102") + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { @@ -599,7 +596,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1, stride=1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -616,7 +613,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -679,8 +676,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) + is_same_v || is_same_v) { const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; @@ -720,8 +716,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } // check Gridwise GEMM - return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + return GridwiseOp::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_, arg.block_2_etile_map_); @@ -776,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle cde_element_op}; } + static auto + MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, 
e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + static auto MakeInvoker() { return Invoker{}; } std::unique_ptr MakeArgumentPointer( @@ -822,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle cde_element_op); } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + std::unique_ptr MakeInvokerPointer() override { return std::make_unique(Invoker{}); @@ -840,9 +986,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle << KPerBlock 
<< ", " << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "ABlockTransferSrcScalarPerVector: " << ABlockTransferSrcScalarPerVector << ", " - << BBlockTransferSrcScalarPerVector - << ">"; + << "BBlockTransferSrcScalarPerVector: " + << BBlockTransferSrcScalarPerVector; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index c80598c4e335aaf8f448d3ad1af6c0e6172768d1..09ff8207574bfdac0c5d8f4cf666149ccde48860 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -1,202 +1,22 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include -#include -#include -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/io.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace { - -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - -/* 
- * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. - * - * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix - * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly - * strided batched, but we can easily extend to other layouts. The returned offset can be either \p - * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB - * limitations. - * - * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and - * returns the 2D index of the tile that it computes. \see - * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). - * - * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 - * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid - * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link - * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for - * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the - * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. - * - * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. - * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to - * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). - * - */ -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) - // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - DsPointer p_ds_grid_grp; - - static constexpr index_t NumDTensor = - 
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); -#else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_ds_grid; - ignore = p_e_grid; - ignore = batch_count; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; - ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; - ignore = compute_ptr_offset_of_batch; - ignore = block_2_ctile_map; -#endif -} - -} // namespace - // // @brief Device Convolution operation. -// +// @note This structure is deprecated (left for backwards compatibility). Please use +// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle. // Supports: // @li Forward convolution with up to 3 spatial dimentions // @li Input tensor in GNWC data format @@ -255,711 +75,63 @@ template -struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle - : public DeviceGroupedConvFwdMultipleD -{ - using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle; - - static constexpr index_t NumDTensor = DsDataType::Size(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; - - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; - - template - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); - - const auto in_gemmm_gemmk_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); - - return in_gemmm_gemmk_desc; - } - - template - static auto - MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) - { - const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); - - const auto wei_gemmn_gemmk_desc = - matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); - - return wei_gemmn_gemmk_desc; - } - - template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) - { - const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); - - const auto out_gemmm_gemmn_desc = - 
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); - - return out_gemmm_gemmn_desc; - } - - static auto MakeDsGridDescriptor_M_N( - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides) - { - return generate_tuple( - [&](auto i) { - using DLayout = remove_cvref_t>; - - return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], - ds_g_n_k_wos_strides[i]); - }, - Number{}); - } - - // desc for problem definition - using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; - - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CShuffleDataType, - DsDataType, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; - - // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = - remove_cvref_t; - using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; - - // block-to-e-tile map - using Block2ETileMap = - remove_cvref_t; - - // Argument - struct Argument : public BaseArgument - { - Argument(const void* p_a, - const void* p_b, - const std::array& p_ds, - void* p_e, - const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>& - ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& - ds_g_n_k_wos_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) - : p_a_grid_{static_cast(p_a)}, - p_b_grid_{static_cast(p_b)}, - p_ds_grid_{}, - p_e_grid_{static_cast(p_e)}, - num_group_{a_g_n_c_wis_lengths[0]}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - 
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, - ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, - a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, - b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, - compute_ptr_offset_of_batch_{}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op}, - a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, - a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, - b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, - b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, - ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, - ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, - e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, - e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, - conv_filter_strides_{conv_filter_strides}, - conv_filter_dilations_{conv_filter_dilations}, - input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads} - { - // A/B/E Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; - - // populate pointer, batch stride, desc for Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - using DDataType = remove_cvref_t>; - - // D pointer - p_ds_grid_(i) = static_cast(p_ds[i]); - - // D batch stride - compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; - - // D desc - ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); - }); - - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - } - } - - void Print() const - { - std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; - std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); - std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; - } - - // private: - // pointers - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - EDataType* p_e_grid_; - - // tensor descriptors for problem definiton - index_t num_group_; - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; - DsGridDesc_M_N ds_grid_desc_m_n_; - EGridDesc_M_N e_grid_desc_m_n_; - - // tensor descriptors for block/thread-wise copy - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; - - // block-to-e-tile map - Block2ETileMap block_2_etile_map_; - - // for 
computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; - - // element-wise op - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - - // for checking IsSupportedArgument() - std::array a_g_n_c_wis_lengths_; - std::array a_g_n_c_wis_strides_; - std::array b_g_k_c_xs_lengths_; - std::array b_g_k_c_xs_strides_; - std::array, NumDTensor> ds_g_n_k_wos_lengths_; - std::array, NumDTensor> ds_g_n_k_wos_strides_; - std::array e_g_n_k_wos_lengths_; - std::array e_g_n_k_wos_strides_; - std::array conv_filter_strides_; - std::array conv_filter_dilations_; - std::array input_left_pads_; - std::array input_right_pads_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) - { - throw std::runtime_error( - "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); - } - - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; - - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - - const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - typename GridwiseGemm::DsGridPointer, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - Block2ETileMap, - ComputePtrOffsetOfStridedBatch, - has_main_loop>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_etile_map_, - arg.compute_ptr_offset_of_batch_); - }; - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}); - } - } - - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static bool IsSupportedArgument(const Argument& arg) - { - namespace ctc = tensor_layout::convolution; - - // check device - if(get_device_name() == "gfx908") - { - if constexpr(!(is_same_v || is_same_v || - is_same_v)) - { - return false; - } - } - else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || - get_device_name() == "gfx941" || get_device_name() == "gfx942") - { - if constexpr(!(is_same_v || is_same_v || - is_same_v || is_same_v)) - { - return false; - } - } - else - { - return false; - } - - // check ConvolutionForwardSpecialization - if constexpr(ConvForwardSpecialization == 
- ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - // check if it's 1x1, stride=1 conv - for(index_t i = 0; i < NDimSpatial; ++i) - { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; - const index_t ConvStride = arg.conv_filter_strides_[i]; - const index_t LeftPad = arg.input_left_pads_[i]; - const index_t RightPad = arg.input_right_pads_[i]; - - if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) - { - return false; - } - } - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Pad0) - { - // check if it's 1x1 conv - for(index_t i = 0; i < NDimSpatial; ++i) - { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; - const index_t LeftPad = arg.input_left_pads_[i]; - const index_t RightPad = arg.input_right_pads_[i]; - - if(!(X == 1 && LeftPad == 0 && RightPad == 0)) - { - return false; - } - } - } - - // check vector access of A - // FIXME: layout - if constexpr(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) - { - const index_t C = arg.a_g_n_c_wis_lengths_[2]; - - if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else - { - return false; - } - - // check vector access of B - // FIXME: layout - if constexpr(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) - - { - const index_t C = arg.b_g_k_c_xs_lengths_[2]; - - if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else - { - return false; - } - - // check vector access of Ds - bool valid = true; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - - // FIXME: layout - if constexpr(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) - { - const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; - - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) - { - valid = false; - } - } - else - { - valid = false; - } - }); - - if(!valid) - { - return false; - } - - // check vector access of E - if constexpr(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) - { - const index_t K = arg.e_g_n_k_wos_lengths_[2]; - - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) - { - return false; - } - } - else - { - return false; - } - - // check Gridwise GEMM - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); - } - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument( - const void* p_a, - const void* p_b, - const std::array& p_ds, - void* p_e, - const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const AElementwiseOperation& a_element_op, - const 
BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) - { - return Argument{p_a, - p_b, - p_ds, - p_e, - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - ds_g_n_k_wos_lengths, - ds_g_n_k_wos_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - a_element_op, - b_element_op, - cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - std::unique_ptr MakeArgumentPointer( - const void* p_a, - const void* p_b, - const std::array& p_ds, - void* p_e, - const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) override - { - return std::make_unique(p_a, - p_b, - p_ds, - p_e, - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - ds_g_n_k_wos_lengths, - ds_g_n_k_wos_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - a_element_op, - b_element_op, - cde_element_op); - } - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " - << MPerXDL << ", " - << NPerXDL << ", " - << MXdlPerWave << ", " - << NXdlPerWave << ", " - << ABlockTransferSrcScalarPerVector << ", " - << BBlockTransferSrcScalarPerVector << ", " - << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle - << ">"; - // clang-format on - - return str.str(); - } -}; + typename AComputeDataType = + decltype(UnpackDataType::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> +using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ConvForwardSpecialization, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + 
BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + AComputeDataType, + BComputeDataType, + LoopSched>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bc7f13441acb8ec99caf7b6625abade579204ac7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -0,0 +1,1054 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle( + Array gemm_desc_kernel_args, + const index_t gemms_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + 
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ && + block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, + gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, + Tuple<>{}, + gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + Tuple<>{}, + gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_kernel_args[group_id].block_2_etile_map_); +#else + ignore = gemm_desc_kernel_args; + ignore = gemms_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; +#endif +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t MaxGemmsNum = 32; + static_assert(NumDTensor == 0, "MultiD not supported."); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ConvToGemmFwdTransformerIndexT = TransformConvFwdToGemm; + + using ConvToGemmFwdTransformerLongIndexT = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto + MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto + MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto + MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // desc for problem definition + constexpr static 
ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + static auto + GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base, + const ADataType* a_grid_ptr_base, + EDataType* c_grid_ptr_base) + { + // Max number of splits + // We need to use it to avoid infinity loop + constexpr index_t max_split_numbers = MaxGemmsNum / 2; + // Arrays to store transformers with smaller descs than 2GB + Array conv_to_gemm_transformers_arr; + Array a_grid_ptrs_arr; + Array c_grid_ptrs_arr; + // Queue for spliting + std::queue conv_to_gemm_transformers_queue( + {conv_to_gemm_transformer_base}); + std::queue a_grid_ptrs_queue({a_grid_ptr_base}); + std::queue c_grid_ptrs_queue({c_grid_ptr_base}); + + index_t gemms_number = 0; + index_t split_numbers = 0; + // Algorithm: + // While queue is not empty: + // 1. Get transformer from queue. + // 2. If descs are smaller than 2GB push to result array. + // 3. If descs are bigger than 2GB split into left and right transformer. + // and push the both into the queue. + while(!conv_to_gemm_transformers_queue.empty() && split_numbers < max_split_numbers && + gemms_number < MaxGemmsNum) + { + // Get transformer from the queue + const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front(); + const ADataType* a_grid_ptr = a_grid_ptrs_queue.front(); + EDataType* c_grid_ptr = c_grid_ptrs_queue.front(); + + // Check if convolution not exceed 2GB + if(conv_to_gemm_transformer.AreDescriptorsSmallerThan2GB()) + { + // If yes, push into result array + conv_to_gemm_transformers_arr(gemms_number) = + ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer}; + a_grid_ptrs_arr(gemms_number) = a_grid_ptr; + c_grid_ptrs_arr(gemms_number) = c_grid_ptr; + gemms_number++; + } + else + { + // If no, split into left and right convolutions + ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part, + conv_to_gemm_transformers_right_part; + const ADataType* a_grid_right_ptr; + EDataType* c_grid_right_ptr; + + ck::tie(conv_to_gemm_transformers_left_part, + conv_to_gemm_transformers_right_part, + a_grid_right_ptr, + c_grid_right_ptr) = + conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr); + + conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part); + conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part); + // Left offsets remain the same + a_grid_ptrs_queue.push(a_grid_ptr); + a_grid_ptrs_queue.push(a_grid_right_ptr); + c_grid_ptrs_queue.push(c_grid_ptr); + c_grid_ptrs_queue.push(c_grid_right_ptr); + split_numbers++; + } + // Remove from the queue + conv_to_gemm_transformers_queue.pop(); + a_grid_ptrs_queue.pop(); + c_grid_ptrs_queue.pop(); + } + + const bool is_split_valid = conv_to_gemm_transformers_queue.empty(); + + return ck::make_tuple(conv_to_gemm_transformers_arr, + a_grid_ptrs_arr, + c_grid_ptrs_arr, + gemms_number, + is_split_valid); + } + +#define GridwiseGemmTemplateParameters \ + ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + 
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + AComputeDataType + // Use appropriate gridwise gemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + // Structure for each gemm(conv) + struct GemmArgs + { + // pointers + const ADataType* a_ptr_; + const BDataType* b_ptr_; + EDataType* e_ptr_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& /*p_ds*/, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + /*ds_g_n_k_wos_lengths*/, + const std::array, NumDTensor>& + /*ds_g_n_k_wos_strides*/, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : num_group_{static_cast(a_g_n_c_wis_lengths[0])}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // Perform grouped gemm, generate array of tranformer for convolution + Array conv_to_gemm_transformer_arr; + Array a_grid_ptrs; + Array c_grid_ptrs; + + ck::tie(conv_to_gemm_transformer_arr, + a_grid_ptrs, + c_grid_ptrs, + gemms_count_, + is_split_valid_) = + GenerateConvToGemmTransforms( + ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + 
conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + static_cast(p_a), + static_cast(p_e)); + + grid_size_ = 0; + valid_gemms_count_ = 0; + + if(is_split_valid_) + { + // Create GemmArg for each gemm(conv) + for(index_t i = 0; i < gemms_count_; i++) + { + const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K( + conv_to_gemm_transformer_arr[i])}; + const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K( + conv_to_gemm_transformer_arr[i])}; + const auto e_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_arr[i]); + + const auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + const index_t grid_size_grp = + block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + Tuple<>{}, + e_grid_desc_m_n, + block_2_etile_map)) + { + + gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ + a_grid_ptrs[i], + static_cast(p_b), + c_grid_ptrs[i], + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd}; + + valid_gemms_count_++; + } + } + // N is the same for all convs + conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); + } + + // Strides for G and N remain the same + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; + } + + void Print() const + { + for(index_t i = 0; i < valid_gemms_count_; i++) + { + std::cout << "A[AK0, M, AK1]: " << gemm_desc_kernel_args_[i].a_grid_desc_ak0_m_ak1_ + << std::endl; + std::cout << "B[BK0, N, BK1]: " << gemm_desc_kernel_args_[i].b_grid_desc_bk0_n_bk1_ + << std::endl; + std::cout + << "E[MBlock, MPerBlock, NBlock, NPerBlock]: " + << gemm_desc_kernel_args_[i].e_grid_desc_mblock_mperblock_nblock_nperblock_ + << std::endl; + } + } + + index_t num_group_; + index_t conv_N_per_block_; + + Array gemm_desc_kernel_args_; + + index_t grid_size_; + index_t gemms_count_; + index_t valid_gemms_count_; + + bool is_split_valid_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 
0) + { + arg.Print(); + } + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + const index_t gdx = arg.grid_size_; + const index_t gdy = arg.num_group_; + const index_t gdz = num_workgroups_per_Conv_N; + + // K is constant for all gemms + const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< + GridwiseGemm, + MaxGemmsNum, + GemmArgs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.gemm_desc_kernel_args_, + arg.gemms_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + const long_index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const long_index_t C = arg.b_g_k_c_xs_lengths_[I2]; + + // Check if all descs are valid + if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_)) + { + return false; + } + // check device + if(get_device_name() == "gfx908") + { + // FIXME: re-enable fp64 when SWDEV-335738 is fixed + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + if(!ck::is_xdl_supported()) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK()) + { + return false; + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || 
is_same_v || + is_same_v || is_same_v || + is_same_v) + { + // Check access per C + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i64; + std::array a_g_n_c_wis_strides_i64; + std::array b_g_k_c_xs_lengths_i64; + std::array b_g_k_c_xs_strides_i64; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i64; + std::array, NumDTensor> ds_g_n_k_wos_strides_i64; + std::array e_g_n_k_wos_lengths_i64; + std::array e_g_n_k_wos_strides_i64; + std::array conv_filter_strides_i64; + std::array conv_filter_dilations_i64; + std::array input_left_pads_i64; + std::array input_right_pads_i64; + + array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i64, conv_filter_strides); + array_convert(conv_filter_dilations_i64, conv_filter_dilations); + array_convert(input_left_pads_i64, input_left_pads); + array_convert(input_right_pads_i64, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i64, + a_g_n_c_wis_strides_i64, + b_g_k_c_xs_lengths_i64, + b_g_k_c_xs_strides_i64, + ds_g_n_k_wos_lengths_i64, + ds_g_n_k_wos_strides_i64, + e_g_n_k_wos_lengths_i64, + e_g_n_k_wos_strides_i64, + conv_filter_strides_i64, + conv_filter_dilations_i64, + input_left_pads_i64, + input_right_pads_i64, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + const 
std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + + std::array a_g_n_c_wis_lengths_i64; + std::array a_g_n_c_wis_strides_i64; + std::array b_g_k_c_xs_lengths_i64; + std::array b_g_k_c_xs_strides_i64; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i64; + std::array, NumDTensor> ds_g_n_k_wos_strides_i64; + std::array e_g_n_k_wos_lengths_i64; + std::array e_g_n_k_wos_strides_i64; + std::array conv_filter_strides_i64; + std::array conv_filter_dilations_i64; + std::array input_left_pads_i64; + std::array input_right_pads_i64; + + array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i64, conv_filter_strides); + array_convert(conv_filter_dilations_i64, conv_filter_dilations); + array_convert(input_left_pads_i64, input_left_pads); + array_convert(input_right_pads_i64, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i64, + a_g_n_c_wis_strides_i64, + b_g_k_c_xs_lengths_i64, + b_g_k_c_xs_strides_i64, + ds_g_n_k_wos_lengths_i64, + ds_g_n_k_wos_strides_i64, + e_g_n_k_wos_lengths_i64, + e_g_n_k_wos_strides_i64, + conv_filter_strides_i64, + conv_filter_dilations_i64, + input_left_pads_i64, + 
input_right_pads_i64, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3bcd8859aaaf136f081acbd12adb0a714107400f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
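// --- Illustrative sketch (not part of this header) ---------------------------
// The per-group and per-N pointer offsets computed by ComputePtrOffsetOfStridedBatch
// (defined below) and consumed by the grouped kernels above follow one rule:
// offset_of_group(g) = g * BatchStride, evaluated in a 64-bit index type so the
// product cannot overflow for tensors larger than 2 GB -- exactly the case the
// large-tensor path splits into multiple GEMMs. A minimal standalone version,
// with illustrative names (long_index_t assumed to be a 64-bit signed integer):
#include <cstdint>

namespace offset_example {
using long_index_t = std::int64_t; // assumption: mirrors CK's long_index_t

// In the style of ComputePtrOffsetOfStridedBatch::GetAPtrOffset / GetEPtrOffset:
// widen the group index first, then multiply by the stride of the G dimension.
inline long_index_t group_ptr_offset(int g_idx, long_index_t batch_stride_g)
{
    return static_cast<long_index_t>(g_idx) * batch_stride_g;
}
} // namespace offset_example
// -----------------------------------------------------------------------------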
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// 1d +template +constexpr bool is_NWGC_GKXC_NWGK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_GNWC_GKXC_GNWK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCW_GKXC_NGKW() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +// 2d +template +constexpr bool is_NHWGC_GKYXC_NHWGK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_GNHWC_GKYXC_GNHWK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCHW_GKYXC_NGKHW() +{ + return is_same_v && + is_same_v && + is_same_v; +} +// 3d +template +constexpr bool is_NDHWGC_GKZYXC_NDHWGK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_GNDHWC_GKZYXC_GNDHWK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCDHW_GKZYXC_NGKDHW() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK() +{ + return is_NWGC_GKXC_NWGK() || + is_NHWGC_GKYXC_NHWGK() || + is_NDHWGC_GKZYXC_NDHWGK(); +} + +template +constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK() +{ + return is_GNWC_GKXC_GNWK() || + is_GNHWC_GKYXC_GNHWK() || + is_GNDHWC_GKZYXC_GNDHWK(); +} + +template +constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial() +{ + return is_NGCW_GKXC_NGKW() || + is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW(); +} + +template +struct ComputePtrOffsetOfStridedBatch +{ +}; + +template +struct ComputePtrOffsetOfStridedBatch 1 || NumBTensor > 1)>> +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs, + Array& BatchStrideBs, + Array& BatchStrideDs, + long_index_t BatchStrideE) + : BatchStrideA_(BatchStrideAs), + BatchStrideB_(BatchStrideBs), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const + { + Array as_offset; + static_for<0, NumATensor, 1>{}( + [&](auto i) { as_offset(i) = static_cast(g_idx) * BatchStrideA_[i]; }); + return as_offset; + } + + __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const + { + Array bs_offset; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { bs_offset(i) = static_cast(g_idx) * BatchStrideB_[i]; }); + return bs_offset; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); + return ds_offset; + } + + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideE_; + } + + // alias for kernels without multiple D + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideE_; + } + + Array BatchStrideA_; + Array BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D +}; + +template +struct ComputePtrOffsetOfStridedBatch> +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, + long_index_t BatchStrideB, + Array BatchStrideDs, + 
long_index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideA_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideB_; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); + return ds_offset; + } + + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideE_; + } + + // alias for kernels without multiple D + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideE_; + } + + long_index_t BatchStrideA_; + long_index_t BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D +}; + +template +constexpr static auto GetNumABTensors() +{ + if constexpr(isTuple) + { + return Number{}; + } + else + { + return Number<1>{}; + } +} + +template +constexpr static auto GetAGridPointer() +{ + if constexpr(isTuple) + { + return typename GridwiseGemm::AsGridPointer{}; + } + else + { + return Tuple{}; + } +} + +template +constexpr static auto GetBGridPointer() +{ + if constexpr(isTuple) + { + return typename GridwiseGemm::BsGridPointer{}; + } + else + { + return Tuple{}; + } +} + +template +constexpr static auto UnpackDataType() +{ + if constexpr(isTuple) + { + // unpack if tuple + return tuple_element_t{}; + } + else + { + // if no, return Type + return Type{}; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21afc0604068ff9fa89badb48ef77018f5e38256 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -0,0 +1,850 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
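// --- Illustrative sketch (not part of this header) ---------------------------
// The is_tuple / is_detected / UnpackDataType machinery above answers one
// compile-time question: "does this data type expose IsTuple(), i.e. is it a
// pack of tensors (multi-A/B) rather than a single tensor?" A plain C++17
// rendering of the same detection idiom, with illustrative names:
#include <type_traits>
#include <utility>

namespace detection_example {
template <typename T, typename = void>
struct has_is_tuple : std::false_type
{
};

template <typename T>
struct has_is_tuple<T, std::void_t<decltype(std::declval<T>().IsTuple())>> : std::true_type
{
};

// Compile-time branch in the style of GetNumABTensors()/UnpackDataType():
// a packed type contributes its own element count, a plain type counts as one.
template <typename T>
constexpr int num_tensors()
{
    if constexpr(has_is_tuple<T>::value) { return T::Size(); }
    else { return 1; }
}

static_assert(!has_is_tuple<float>::value, "a plain data type is not a tuple");
} // namespace detection_example
// -----------------------------------------------------------------------------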
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const index_t grid_size_grp, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t KBatch = 1; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; + const auto StrideBs = gemm_desc_ptr[group_id].StrideBs; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); + constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); + constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); + + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t; + p_as_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_as_grid[i]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t; + p_bs_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_bs_grid[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t; + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + GridwiseGemm:: + template Run( + p_as_grid_, + p_bs_grid_, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + block_2_etile_map); + + 
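// Persistent-style advance: this workgroup started at tile
// (block_id - BlockStart) within its group and now steps forward by
// grid_size_grp, so the fixed per-group grid covers every tile of this
// group's E matrix even when the runtime M makes local_grid_size larger
// than grid_size_grp.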
id_off += grid_size_grp; + id_local += grid_size_grp; + } +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = grid_size_grp; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK + : public DeviceGroupedGemmMultiABDFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t NumGemmKPrefetchStage = 1; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple( + // idx_bot[Number<0>{}], + idx_bot[Number<1>{}], + idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const 
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? 
M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + struct GemmBiasTransKernelArg + { + // pointers + std::array as_ptr_; + std::array bs_ptr_; + std::array ds_ptr_; + void* e_ptr_; + + index_t M_, N_, K_; + std::array StrideAs_; + std::array StrideBs_; + std::array StrideDs_; + index_t StrideE_; + }; + + // Argument + struct Argument : public BaseArgument + { + + void UpdateKBatch(index_t) {} + + Argument(std::vector>&, + std::vector>&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + // pointer + std::array p_as_grid; + std::array p_bs_grid; + std::array p_ds_grid; + + static_for<0, NumATensor, 1>{}([&](auto j) { p_as_grid[j] = nullptr; }); + static_for<0, NumBTensor, 1>{}([&](auto j) { p_bs_grid[j] = nullptr; }); + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + + const index_t StrideE = gemm_descs[i].stride_C_; + + if(gemm_descs[i].stride_As_.size() != NumATensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor"); + } + + static_for<0, NumATensor, 1>{}( + [&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; }); + + if(gemm_descs[i].stride_Bs_.size() != NumBTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor"); + } + + static_for<0, NumBTensor, 1>{}( + [&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; }); + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! 
gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + static_for<0, NumDTensor, 1>{}( + [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + // check block-to-E-tile + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + p_as_grid, + p_bs_grid, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + }); + + group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_ = 1; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! 
grouped_gemm_kernel_args_dev is nullpr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + const auto kernel = kernel_grouped_gemm_xdl_fixed_nk< + GridwiseGemm, + GroupedGemmMultiABDKernelArgument, + GemmSpec, + AsLayout, + BsLayout, + DsLayout, + ELayout, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; + + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by vector + // load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, + // thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 
1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl_Fixed_NK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetElementwiseOps(Argument& arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + arg.a_element_op_ = a_element_op; + arg.b_element_op_ = b_element_op; + arg.c_element_op_ = c_element_op; + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + void SetElementwiseOps(BaseArgument* p_arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) const override + { + + SetElementwiseOps( + *dynamic_cast(p_arg), a_element_op, b_element_op, c_element_op); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * + sizeof(GroupedGemmMultiABDKernelArgument); + } + +#if 0 + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * arg.barrier_size_grp_ * 
sizeof(uint32_t); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); + } +#endif + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + return SetKBatch(*dynamic_cast(p_arg), k_batch); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index 0190b3cee6b44af7656c34be866c9ab0855f7b0d..060a16d1e21e9265481467bb6e03f0341c3cf886 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -39,9 +39,9 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ + defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -554,24 +554,29 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm +#include +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template = false> +struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage + : public DeviceGroupedGemmMultipleDSplitK +{ + using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + // TODO change GridwiseGEMM v2r4r2 to support separate AK1 & BK1 + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using WorkspaceDataType = float; + + // First stage GridwiseGEMM kernel. 
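+    // Two-stage scheme: the GridwiseGemm below (stage one) accumulates split-K partial
+    // results into a per-group float workspace (zeroed before launch) and therefore uses
+    // PassThrough as its C elementwise op; a GridwiseElementwise kernel (stage two) then
+    // combines the workspace with the D tensors through CDEElementwiseOperation and
+    // writes the final E output.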
+ using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, + BDataType, + AccDataType, + WorkspaceDataType, + ALayout, + BLayout, + ELayout, + AElementwiseOperation, + BElementwiseOperation, + PassThrough, // CElementwiseOperation + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + AK1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeDataType>; + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeElementwiseInputSequence() + { + return generate_sequence_v2( + [&]([[maybe_unused]] auto i) constexpr { + return Number{}; + }, + Number{}); + } + + using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using EGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using DsGridPointer = decltype(MakeDsGridPointer()); + using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple{}, DsGridDesc_M_N{})); + using CDDataTypes = decltype(concat_tuple(ck::Tuple{}, DsGridPointer{})); + + using ElementwiseInputSequence = decltype(MakeElementwiseInputSequence()); + + static constexpr index_t ClusterLengthMPerBlock = + 
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + using GridwiseElementwise = + GridwiseElementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + ElementwiseInputSequence, + ck::Sequence, + I1, + I1>; + + // Block2CTileMap configuration parameter. + static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using GemmKernelArgument = typename GridwiseGemm::Argument; + + struct GemmTransKernelArg + { + GemmKernelArgument karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArg() = default; + GemmTransKernelArg(GemmKernelArgument&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + + static constexpr index_t DefaultKBatch = 1; + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + cde_element_op, + DefaultKBatch) + { + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t kbatch) + : K_BATCH{kbatch}, + group_count_{0}, + skipped_group_count_{0}, + grid_size_{0}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + p_Ds_{p_Ds} + { + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("Error! 
group_count_ != p_As/Bs/Ds/Es size"); + } + + gemm_kernel_args_.reserve(group_count_); + elementwise_c_grid_descs_m_n_.reserve(group_count_); + elementwise_d_grid_descs_m_n_.reserve(group_count_); + ds_grid_pointer_.reserve(group_count_); + group_grid_size_.reserve(group_count_); + e_ptrs_.reserve(group_count_); + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0 || N == 0 || K == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_e = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH); + + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_e); + + DsGridDesc_M_N ds_grid_desc_m_n; + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + p_ds_grid(j) = static_cast(p_Ds[i][j]); + ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N( + M, N, gemm_descs[i].stride_Ds_[j]); + }); + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + group_grid_size_.push_back(grid_size_grp); + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + std::array stride_ds; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + stride_ds[j] = gemm_descs[i].stride_Ds_[j]; + }); + stride_Ds_.emplace_back(std::move(stride_ds)); + + // We first set E pointer to actual operation output, but later on + // when workspace will be set, this will be updated to workspace memory. + auto karg = GemmKernelArgument{type_convert(p_As[i]), + type_convert(p_Bs[i]), + type_convert(p_Es[i]), + M, + N, + K, + stride_a, + stride_b, + stride_e, + m_padded, + n_padded, + k_padded, + k0_padded, + K_BATCH}; + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + + elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n); + elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n); + ds_grid_pointer_.push_back(p_ds_grid); + // Store a copy of E pointers for elementwise kernel destination + e_ptrs_.push_back(p_Es[i]); + } + } + + /** + * @brief Set new kbatch value. + * + * @param[in] kbatch The new splitK parameter value. 
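+         *
+         * @note Recomputes KPadded/K0Padded, the per-group block-to-tile maps and the
+         *       total grid size for the new split-K factor.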
+ */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + group_grid_size_[i] = grid_size_grp; + karg.KPadded = k_padded; + karg.K0Padded = k0_padded; + karg.k_batch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + index_t tiles = (block_end - block_start) / K_BATCH; + std::cout << "block_start: " << block_start << "\n" + << "block_end: " << block_end << "\n" + << "tiles: " << tiles << std::endl + << std::endl; + + std::cout << "KPadded: " << karg.KPadded << std::endl + << "K0Padded: " << karg.K0Padded << std::endl + << "KBatch: " << karg.k_batch << std::endl + << "grid_size_: " << karg.KPadded << std::endl; + } + } + } + + void UpdateEPointers() + { + // set-up each group E pointer to it's designated workspace memory. + WorkspaceDataType* p_workspace = reinterpret_cast(p_workspace_); + std::size_t offset = 0; + + for(auto& arg : gemm_kernel_args_) + { + arg.karg_.p_c_grid = p_workspace + offset; + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + offset += tiles * MPerBlock * NPerBlock; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "block_start: " << arg.block_start_ << "\n" + << "block_end: " << arg.block_end_ << "\n" + << "tiles: " << tiles << "\n" + << "offset: " << offset << std::endl; + } + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + std::size_t size_bytes{0}; + + for(const auto& arg : gemm_kernel_args_) + { + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + size_bytes += tiles * MPerBlock * NPerBlock * sizeof(WorkspaceDataType); + } + return size_bytes; + } + + std::size_t GetWorkspaceSize(std::size_t group) const + { + const auto& arg = gemm_kernel_args_[group]; + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + return tiles * MPerBlock * NPerBlock; + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + index_t grid_size_; + // Pointer to device memory with GEMM kernel arguments. + const void* p_dev_gemm_args_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + std::vector>& p_Ds_; + std::vector> stride_Ds_; + std::vector gemm_kernel_args_; + std::vector group_grid_size_; + + std::vector elementwise_c_grid_descs_m_n_; + std::vector elementwise_d_grid_descs_m_n_; + std::vector ds_grid_pointer_; + std::vector e_ptrs_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. 
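+        ///
+        /// A typical host-side call sequence (illustrative sketch; the plain hipMalloc
+        /// calls stand in for whatever buffer management the caller actually uses):
+        /// @code
+        /// auto arg     = DeviceOp::MakeArgument(p_As, p_Bs, p_Ds, p_Es, gemm_descs,
+        ///                                       a_op, b_op, cde_op);
+        /// DeviceOp op;
+        /// void* p_ws   = nullptr;
+        /// void* p_args = nullptr;
+        /// hip_check_error(hipMalloc(&p_ws, op.GetWorkSpaceSize(&arg)));
+        /// op.SetWorkSpacePointer(&arg, p_ws);   // redirects per-group E pointers to the workspace
+        /// hip_check_error(hipMalloc(&p_args, op.GetDeviceKernelArgSize(&arg)));
+        /// op.SetDeviceKernelArgs(&arg, p_args); // copies the updated kernel args to device
+        /// auto invoker = DeviceOp::MakeInvoker();
+        /// invoker.Run(arg, p_args, p_ws, StreamConfig{});
+        /// @endcode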
+ /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] dev_gemm_workspace The pointer to device memory for kernel auxiliary + /// workspace. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + const void* dev_gemm_args, + void* dev_gemm_workspace, + const StreamConfig& stream_config = StreamConfig{}) + { + auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] = + CheckArgument(arg, stream_config); + + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(dev_gemm_workspace == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + float ave_time = 0; + + if(all_have_main_k_block_loop) + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + else + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + + return ave_time; + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, @see GetWorkSpaceSize and @see + /// SetDeviceKernelArgs, @see SetWorkSpacePointer on arg parameter to properly + /// allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(arg.p_workspace_ == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" 
+ << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + + private: + auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const + { + bool all_have_kbatch_gt_one, all_have_main_k_block_loop; + + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1( + arg.gemm_kernel_args_[0].karg_.M, + arg.gemm_kernel_args_[0].karg_.MPadded, + arg.gemm_kernel_args_[0].karg_.K, + arg.gemm_kernel_args_[0].karg_.StrideA, + arg.gemm_kernel_args_[0].karg_.k_batch, + arg.gemm_kernel_args_[0].karg_.K0Padded, + arg.gemm_kernel_args_[0].karg_.KPadded); + + all_have_kbatch_gt_one = arg.K_BATCH > 1; + all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + } + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M, + gemm_arg.MPadded, + gemm_arg.K, + gemm_arg.StrideA, + gemm_arg.k_batch, + gemm_arg.K0Padded, + gemm_arg.KPadded); + + bool not_all_have_main_k_block_loop_same = + all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + bool not_all_have_kbatch_value_same = + all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1); + + if(not_all_have_main_k_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! 
" + << "group [" << i << "], kbatch: " << gemm_arg.k_batch + << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop); + } + + template + float DispatchKernel(const Argument& arg, + const void* dev_gemm_args, + void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + const auto gemm_kernel = + kernel_grouped_gemm_xdl_splitk; + + const auto elementwise_kernel = kernel_elementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation>; + return LaunchKernel(gemm_kernel, + elementwise_kernel, + arg, + dev_gemm_args, + dev_gemm_workspace, + stream_config); + } + + template + float LaunchKernel(const KernelFunction& gemm_kernel, + const KernelFunction2& elementwise_kernel, + const Argument& arg, + const void* dev_gemm_args, + [[maybe_unused]] void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + float time{0.f}; + + auto preprocess = [&]() { + hip_check_error(hipMemsetAsync( + dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + // GEMM kernel + time = launch_and_time_kernel_with_preprocess( + stream_config, + preprocess, + gemm_kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.gemm_kernel_args_.size(), + arg.a_element_op_, + arg.b_element_op_, + PassThrough{}); + + // Elementwise kernels + for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + time += launch_and_time_kernel( + stream_config, + elementwise_kernel, + dim3(arg.group_grid_size_[i]), + dim3(BlockSize), + 0, + concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + arg.elementwise_d_grid_descs_m_n_[i]), + make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid), + arg.ds_grid_pointer_[i]), + type_convert(arg.e_ptrs_[i]), + Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0), + arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)}, + arg.cde_element_op_); + } + return time; + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + + bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); + if(not group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + gemm_arg.Print(); + } + } + supported = supported && group_arg_valid; + } + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpy(p_dev_kernel_args, + arg.gemm_kernel_args_.data(), + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + void SetWorkSpacePointer( + BaseArgument* p_arg, + void* p_workspace, + [[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + p_arg_->UpdateEPointers(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + return SetKBatchSize(*dynamic_cast(p_arg), kbatch); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { 
+ return dynamic_cast(p_arg)->gemm_kernel_args_.size() * + sizeof(GemmTransKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2884e558cd359494dad4119133348af61765e08a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -0,0 +1,957 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/host_utility/stream_utility.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Entry point kernel for device-wide Grouped GEMM operation. +/// +/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures. +/// @param[in] group_count The number of together processed GEMMs. +/// +/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation. +/// @tparam GemmDesc The structure holding all necessary descriptors and +/// other data needed for grouped gemm calculation and work +/// distribution. +/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids, +/// the data tiles to process and the output tiles. 
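+///
+/// @note The kernel is launched with a fixed, occupancy-derived grid. Each workgroup
+///       processes output tiles in a persistent loop: it advances its tile id by the
+///       grid size and scans the per-group tile ranges to find the GEMM a tile belongs to.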
+/// +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared1[shared_size]; + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + constexpr auto NumDTensor = DsDataType::Size(); + index_t tile_id = get_block_1d_id(); + index_t tile_offset = 0; + index_t group_id = -1; + index_t group_offset = 0; + index_t grid_size_grp = 0; + + index_t gemm_tile_id_start = 0; + index_t gemm_tile_id_end = 0; + + index_t M = 0, N = 0, K = 0; + + auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); + + do + { + // Find corresponding GEMM group for our tile + while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && + group_id < group_count) + { + group_offset += grid_size_grp; + group_id++; + + if(group_id >= group_count) + return; + + M = gemm_desc_ptr[group_id].M; + N = gemm_desc_ptr[group_id].N; + K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + { + grid_size_grp = 0; + continue; + } + + b2c_tile_map = + OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); + grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + + gemm_tile_id_start = group_offset; + gemm_tile_id_end = group_offset + grid_size_grp; + } + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + static constexpr index_t kbatch = 1; + static constexpr index_t k_grain = kbatch * KPerBlock; + index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + // Update tile offset if we have moved within group + b2c_tile_map.UpdateTileOffset(tile_offset); + + using Problem = typename GridwiseGemm::Problem; + auto problem = Problem(gemm_desc_ptr[group_id].M, + gemm_desc_ptr[group_id].N, + gemm_desc_ptr[group_id].K, + gemm_desc_ptr[group_id].StrideA, + gemm_desc_ptr[group_id].StrideB, + gemm_desc_ptr[group_id].StrideDs, + gemm_desc_ptr[group_id].StrideE, + kbatch); + + if(has_main_k_block_loop) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + 
static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + GridwiseGemm::template Run_2Lds( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + 
static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + static_cast(p_shared1), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else + { + GridwiseGemm::template Run_2Lds( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + static_cast(p_shared1), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + tile_id += get_grid_size(); + tile_offset += get_grid_size(); + + } while(group_id < group_count); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template + +struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop + : public DeviceGroupedGemmTileLoop +{ + using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop; + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using KernelArguments = GroupedGemmTileLoopKernelArguments; + using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& /* p_As */, + std::vector& /* p_Bs */, + std::vector>& /* p_Ds */, + std::vector& /* p_Es */, + const std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + int occupancy_num_blocks, + int gpu_cu_count) + : group_count_{static_cast(gemm_descs.size())}, + occupancy_num_blocks_{occupancy_num_blocks}, + gpu_cu_count_{gpu_cu_count}, + gemm_descs_{gemm_descs}, + 
a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + tile_count_{0} + { + for(const auto& desc : gemm_descs) + { + const auto M = desc.M_; + const auto N = desc.N_; + const auto b2c_tile_map = Block2ETileMap(M, N); + tile_count_ += b2c_tile_map.CalculateGridSize(M, N); + } + } + + index_t group_count_; + const void* p_dev_gemm_args_; + int occupancy_num_blocks_; + int gpu_cu_count_; + const std::vector& gemm_descs_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + index_t tile_count_; + }; + + struct KernelConfig + { + // The oversubscription factor for the number of blocks that can simultaneously reside on + // GPU. + static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; + static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); + static constexpr int CU_SIMDS = 4; + // Assume we want to have at most 2 waves per SIMD + static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config = StreamConfig{}) + { + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + float ave_time = 0; + ave_time = DispatchKernel(arg, dev_gemm_args, stream_config); + + return ave_time; + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on arg + /// parameter to properly allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" 
+ << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + + private: + float DispatchKernel(const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config) const + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + return LaunchKernel(kernel, arg, dev_gemm_args, stream_config); + } + + template + int CalculateMaxOccupancyGridSize(const KernelFunction& kernel, + const StreamConfig& stream_config) const + { + // Calculate max number of workgroups that can simultaneously reside on the CU. + int occ_num_blocks = 0; + size_t dyn_shared_mem_per_blk = 0; + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk)); + + int cu_count = getAvailableComputeUnitCount(stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks + << ", available CUs count: " << cu_count << ", occup. grid size: " + << ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS) * cu_count + << std::endl; + } + + return cu_count * ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS); + } + + template + float LaunchKernel(const KernelFunction& kernel, + const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config) const + { + int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_ + << std::endl; + } + + // run multiple kernels + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + bool supported = true; + + constexpr index_t k_batch = 1; + for(index_t i = 0; i < arg.group_count_; ++i) + { + std::array placeholder_p_ds_grid{}; + std::array stride_Ds; + std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); + using GridArg = typename GridwiseGemm::Argument; + GridArg gridwise_arg(nullptr, // p_a_grid, + nullptr, // p_b_grid, + placeholder_p_ds_grid, // p_ds_grid, + nullptr, // p_e_grid , + arg.gemm_descs_[i].M_, + arg.gemm_descs_[i].N_, + arg.gemm_descs_[i].K_, + arg.gemm_descs_[i].stride_A_, + arg.gemm_descs_[i].stride_B_, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + k_batch, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + + if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + supported = supported && GridwiseGemm::CheckValidity(gridwise_arg); + } + + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto 
MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + int occupancy, num_cu; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + int occupancy, num_cu; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::ostringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return 
dynamic_cast(p_arg)->group_count_ * sizeof(KernelArguments); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index ddd762af1c6dc2d6b2d45716b20f19a2a6e45207..871fbd019e7f68348399677e4c6e5a0ae7c22a52 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -44,7 +44,7 @@ __global__ void const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index db89bee96ce1f6530bd0ef2345f3655963291798..658f3235168f247ef888552a2a07a72f5e6fc0f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -39,7 +39,7 @@ __global__ void const CDEElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -228,9 +228,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -510,28 +514,29 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + uint32_t* barrier_count, + const index_t barrier_size_grp, + const index_t group_count, + const index_t grid_size_grp, + const index_t KBatch, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char 
p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideA = gemm_desc_ptr[group_id].StrideA; + const auto StrideB = gemm_desc_ptr[group_id].StrideB; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + const index_t mn_blocks = local_grid_size / KBatch; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + if constexpr(Zeroing) + { + auto barrier_count_finished = + barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; + GridwiseGemm::template RunWithZeroing(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + else + { + + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + nullptr, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + + id_off += grid_size_grp; + id_local += grid_size_grp; + } +#else + ignore = gemm_descs_const; + ignore = barrier_count; + ignore = barrier_size_grp; + ignore = group_count; + ignore = grid_size_grp; + ignore = KBatch; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using AComputeType = ComputeType; + using BComputeType = ComputeType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle< + ADataType, // TODO: distinguish A/B datatype + BDataType, + AComputeType, + BComputeType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumPrefetch, // NumGemmKPrefetchStage + BlockSize, + MPerBlock, + 
NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer, + ALDSType, + BLDSType>; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t 
CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + struct GemmBiasTransKernelArg + { + // pointers + const void* a_ptr_; + const void* b_ptr_; + std::array ds_ptr_; + void* e_ptr_; + + index_t M_, N_, K_; + index_t StrideA_, StrideB_; + std::array StrideDs_; + index_t StrideE_; + }; + + // Argument + struct Argument : public BaseArgument + { + + void UpdateKBatch(index_t k_batch) + { + k_batch_ = k_batch; + + if(k_batch_ < 1) + { + + throw std::runtime_error("wrong! 
k_batch must be > 0"); + } + + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + + const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_; + const index_t N = gemm_desc_kernel_arg_[0].N_; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + grid_size_ = grid_size_grp_ * group_count_; + } + + Argument(std::vector&, + std::vector&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideE = gemm_descs[i].stride_C_; + + // pointer + std::array p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideDs; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + // using DLayout = remove_cvref_t>; + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + StrideDs[j] = gemm_descs[i].stride_Ds_[j]; + }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + // check block-to-E-tile + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + if(!GridwiseGemm:: + template CheckValidity( + AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + nullptr, + nullptr, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + }); + + group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + const auto KPad = + GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + if(arg.k_batch_ == 1) + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + false, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + nullptr, + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + true, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + reinterpret_cast(arg.p_workspace_), + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + }; + + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; + + // For bf16 datatype only kbatch = 1 scenario is supported. 
This condition is + // enforced in IsSupportedArgument function + if constexpr(std::is_same::value) + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + } + else + { + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel( + integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel( + integral_constant{}, + integral_constant{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by + // vector load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} + // layout, thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + // For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd + // instruction that supports bf16 and we cannot use splitk because of that + if constexpr(std::is_same::value) + { + supported = supported & (arg.k_batch_ == 1); + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl_Fixed_NK" + << "<" + << BlockSize << ", " + << 
MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * sizeof(GroupedGemmKernelArgument); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); + } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + return SetKBatch(*dynamic_cast(p_arg), k_batch); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 41445aeaf076408adc7159359884803f0965e313..6d9d1459c8f54e6d506268cf21650dd1010d7d5d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -26,16 +26,22 @@ namespace device { template + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, + typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, + typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count) + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -64,10 +70,16 @@ __global__ void GridwiseGemm::template Run( gemm_desc_ptr[group_id].karg_, static_cast(p_shared), - gemm_desc_ptr[group_id].block_2_ctile_map_); + gemm_desc_ptr[group_id].block_2_ctile_map_, + a_element_op, + b_element_op, + c_element_op); #else ignore = gemm_descs_const; ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -193,7 +205,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; using KernelArgument = typename GridwiseGemm::Argument; - + using PassThrough = ck::tensor_operation::element_wise::PassThrough; struct GemmTransKernelArg { KernelArgument karg_; @@ -265,10 +277,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 1; bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); @@ -384,7 +396,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 1); @@ -437,7 +449,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { -#if DEBUG_LOG - std::cout << "The group count is not equal to sum of skipped groups " - "and kernel args size!" - << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" 
+ << std::endl; + } return false; } @@ -529,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Multi-Query Attention (MQA) kernel implementation +// Assume number of head of K,V is 1. +// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N] +// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O] +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** + const auto q_head = G1; + const auto kv_head = QueryGroupNumber; +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, q_head, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K] + : std::array{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, kv_head, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K] + : std::array{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, 1, N, K] + + std::array b1_gs_os_ns_lengths{G0, kv_head, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O] + : std::array{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, 1, N, O] + + std::array c_gs_ms_os_lengths{G0, q_head, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O] + : std::array{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. 
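+    // Note on the head mapping (descriptive comment, inferred from the code below): the Q and
+    // Out descriptors use q_head (= G1) while the K and V descriptors use kv_head
+    // (= QueryGroupNumber), so each group of G1 / QueryGroupNumber query heads shares a single
+    // K/V head. The per-batch offsets further down map g_idx to g_idx * QueryGroupNumber / G1
+    // for B0/B1, and IsSupportedArgument() rejects arguments with G1 % QueryGroupNumber != 0.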
+ + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB0BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB1BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = alpha; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceGroupedQueryAttentionForward_Wmma + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO 
ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceGroupedQueryAttentionForward_Wmma; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + 
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + 
B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(arg.G1_ % QueryGroupNumber != 0) + { + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? 
std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + 
a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_grouped_query_attention_wmma; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 
Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + 
b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGroupedQueryAttentionForward_Wmma, " + << "QueryGroupNumber: " + << QueryGroupNumber << ", " + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..648736fcbfafe0cf1ee987521b91af8ac3c876c2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
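Before the header body: the DeviceImageToColumnImpl added below rearranges a convolution input image into the GEMM operand described in its own comment, i.e. rows of N * Do * Ho * Wo (replicated per group G via batch offsets) and columns of Z * Y * X * C, which is what MakeInputDescriptor_M_K and MakeOutDescriptor_M_K encode. As a quick orientation, here is a small standalone sketch, not part of this header, that computes those two GEMM extents for a 2-D convolution using the standard output-length formula; the function and type names and the plain std::array signature are illustrative only.

// Illustrative sketch only: computes the M and K extents of the im2col GEMM
// view of a 2-D convolution for a single group. Names are hypothetical and
// are not part of the composable_kernel API.
#include <array>
#include <cstdint>
#include <iostream>

struct Im2ColGemmExtents
{
    std::int64_t m; // N * Ho * Wo (per group G)
    std::int64_t k; // Y * X * C
};

inline Im2ColGemmExtents example_im2col_gemm_extents(std::int64_t N,
                                                     std::int64_t C,
                                                     std::array<std::int64_t, 2> in_hw,     // {Hi, Wi}
                                                     std::array<std::int64_t, 2> filter_hw, // {Y, X}
                                                     std::array<std::int64_t, 2> strides,
                                                     std::array<std::int64_t, 2> dilations,
                                                     std::array<std::int64_t, 2> left_pads,
                                                     std::array<std::int64_t, 2> right_pads)
{
    std::array<std::int64_t, 2> out_hw{};
    for(int d = 0; d < 2; ++d)
    {
        // standard convolution output-length formula with dilation and padding
        const auto eff_filter = dilations[d] * (filter_hw[d] - 1) + 1;
        out_hw[d] = (in_hw[d] + left_pads[d] + right_pads[d] - eff_filter) / strides[d] + 1;
    }
    return {N * out_hw[0] * out_hw[1],        // rows:    N * Ho * Wo
            filter_hw[0] * filter_hw[1] * C}; // columns: Y * X * C
}

int main()
{
    // N=1, C=3, 8x8 image, 3x3 filter, stride 1, dilation 1, pad 1 -> Ho = Wo = 8
    const auto e = example_im2col_gemm_extents(1, 3, {8, 8}, {3, 3}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
    std::cout << "M = " << e.m << ", K = " << e.k << '\n'; // prints M = 64, K = 27
}

The device implementation that follows builds the same quantities as padded tensor descriptors (matrix_padder pads M and K up to MPerBlock / KPerBlock) rather than plain integers, and handles the group dimension through ComputePtrOffsetOfStridedBatch.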
+ +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Image to column: +// input : input image [G, N, Di, Hi, Wi, C] +// output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C] +// input : input image [N, Di, Hi, Wi, G, C] +// output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C] +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct DeviceImageToColumnImpl + : public DeviceConvTensorRearrange +{ + static constexpr bool is_NSpatialGC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + static constexpr bool is_GNSpatialC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using ConvToGemmFwdTransformer = + TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{ + MPerBlock, 0 /* NPerBlock*/, KPerBlock}; + + // Use MakeADescriptor_M_K from grouped convolution forward + static auto + MakeInputDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + std::array a_g_n_c_wis_lengths{1}; + std::array b_g_k_c_xs_lengths{1}; + std::array c_g_n_k_wos_lengths{1}; + + auto copy = [](const auto& x, auto& y, index_t dst_offset) { + std::copy(x.begin(), x.end(), y.begin() + dst_offset); + }; + + constexpr index_t spatial_offset = 3; + + copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset); + copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset); + copy(output_spatial_lengths, c_g_n_k_wos_lengths, spatial_offset); + + // fill only significant values (C and N) + a_g_n_c_wis_lengths[I1] = N; + a_g_n_c_wis_lengths[I2] = C; + b_g_k_c_xs_lengths[I2] = C; + c_g_n_k_wos_lengths[I1] = N; + + ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + return in_gemmm_gemmk_desc; + } + + static 
auto + MakeOutDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& gemm_g_m_k_strides) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(NDoHoWo, CZYX), make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2])); + return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw); + } + + using InputGridDesc = + remove_cvref_t; + using OutputGridDesc = remove_cvref_t; + + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + OutputGridDesc{}))>; + + using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange>; + + struct Argument : public BaseArgument + { + Argument(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + : G_(G), + C_(C), + X_(filter_spatial_lengths[NDimSpatial - I1]), + p_in_{static_cast(p_in)}, + p_out_{static_cast(p_out)}, + image_g_n_c_wis_strides_{image_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + + in_grid_desc_m_k_ = MakeInputDescriptor_M_K(N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + out_grid_desc_m_k_ = MakeOutDescriptor_M_K( + N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides); + + compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0]; + } + + void Print() const + { + std::cout << in_grid_desc_m_k_ << std::endl; + std::cout << out_grid_desc_m_k_ << std::endl; + } + + const ck::index_t G_; + const ck::index_t C_; + const ck::index_t X_; + + const InputDataType* p_in_; + OutputDataType* p_out_; + + const std::array& image_g_n_c_wis_strides_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + + InputGridDesc in_grid_desc_m_k_; + OutputGridDesc out_grid_desc_m_k_; + + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt( + arg.out_grid_desc_m_k_); + const index_t grid_size = + block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_) * arg.G_; + const auto kernel = kernel_tensor_rearrange, + GridwiseTensorRearrangeKernel>; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + 
dim3(BlockSize), + 0, + arg.in_grid_desc_m_k_, + arg.p_in_, + arg.out_grid_desc_m_k_, + arg.p_out_, + arg.G_, + block_2_tile_map, + arg.compute_ptr_offset_of_batch_); + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const Argument& arg) + { + if constexpr(!(is_NSpatialGC || is_GNSpatialC)) + { + return false; + } + + const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1]; + const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; + const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; + const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; + bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; + bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1; + + // check vector acces with c not packed + if(!is_c_packed && ScalarPerVector != 1) + return false; + // check vector access of filter window row (only C if C is not packed) + if(!is_w_packed && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of filter window row (X * C) + if(arg.X_ * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of pads (w_pad_left/w_pad_right * C) + if(w_pad_left * arg.C_ % ScalarPerVector != 0 || + w_pad_right * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with stride and pad + if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with dilation + if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + + return GridwiseTensorRearrangeKernel::CheckValidity(arg.in_grid_desc_m_k_, + arg.out_grid_desc_m_k_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + return Argument{static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) override + { + return std::make_unique(static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + 
image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceImageToColumn" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << KPerBlock << ", " + << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp deleted file mode 100644 index 175994d49f3a3121e2e32a4ca2ef07e60c0e52e5..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp +++ /dev/null @@ -1,316 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/stream_utility.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// output[indices] = input -template -struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd -{ - using DInDataType_AutomicAddPreCast = - conditional_t || is_same_v, - DInDataType, - float>; - - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert; - - static constexpr auto I0 = Number<0>{}; - - template - static auto PadDescriptor_M_1d(Desc_M desc_m, index_t loop_step) - { - const auto m = desc_m.GetLength(I0); - const auto pad = math::integer_least_multiple(m, loop_step) - m; - const auto desc_m_pad = - transform_tensor_descriptor(desc_m, - make_tuple(make_right_pad_transform(m, pad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - return desc_m_pad; - } - - static auto MakeDescriptor_M(index_t length, index_t loop_step) - { - const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); - return PadDescriptor_M_1d(desc_m, loop_step); - } - - using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1)); - - using GridwisePutElementSet = GridwisePutElement_1D; - - using GridwisePutElementAtomicAdd = GridwisePutElement_1D; - - using GridwiseCasting = GridwiseElementwise_1D, - Tuple, - Tuple, - Tuple, - UnaryConvert, - InOutVectorSize, - Sequence, - Sequence>; - - struct Argument : public BaseArgument - { - Argument(const DOutDataType* p_dout, - const IndexDataType* p_indices, - DInDataType* p_din, - index_t dout_length, - index_t din_length, - const std::vector& window_lengths, - const std::vector& window_strides) - : p_dout_{p_dout}, - p_indices_{p_indices}, - p_din_{p_din}, - dout_length_raw_{dout_length}, - din_length_raw_{din_length}, - blockSize_{256}, - windowOverlap_{false} - { - for(size_t i = 0; i < window_lengths.size(); ++i) - { - 
windowOverlap_ |= window_lengths.at(i) > window_strides.at(i); - } - } - - const DOutDataType* p_dout_; - const IndexDataType* p_indices_; - DInDataType* p_din_; - index_t dout_length_raw_; - index_t din_length_raw_; - index_t blockSize_; - bool windowOverlap_; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - index_t gridSize = getAvailableComputeUnitCount(stream_config); - index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize; - InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step); - InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step); - - if constexpr(is_same_v || is_same_v) - { - hip_check_error(hipMemsetAsync(arg.p_din_, - 0, - arg.din_length_raw_ * sizeof(DInDataType), - stream_config.stream_id_)); - - if(arg.windowOverlap_) - { - const auto put_kernel = kernel_put_element_1d; - - return launch_and_time_kernel(stream_config, - put_kernel, - dim3(gridSize), - dim3(arg.blockSize_), - 0, - dout_grid_desc, - arg.p_dout_, - arg.p_indices_, - arg.p_din_, - PassThrough{}); - } - else - { - const auto put_kernel = kernel_put_element_1d; - - return launch_and_time_kernel(stream_config, - put_kernel, - dim3(gridSize), - dim3(arg.blockSize_), - 0, - dout_grid_desc, - arg.p_dout_, - arg.p_indices_, - arg.p_din_, - PassThrough{}); - } - } - else - { - if(arg.windowOverlap_) - { - if(arg.p_workspace_ == nullptr) - throw std::runtime_error("wrong! WorkSpace pointer has not been set"); - - hip_check_error( - hipMemsetAsync(arg.p_workspace_, - 0, - arg.din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast), - stream_config.stream_id_)); - - const auto put_kernel = kernel_put_element_1d; - - const auto cast_kernel = - kernel_elementwise_1d, - Tuple, - Tuple, - Tuple, - UnaryConvert>; - - float elapsed_time = launch_and_time_kernel( - stream_config, - put_kernel, - dim3(gridSize), - dim3(arg.blockSize_), - 0, - dout_grid_desc, - arg.p_dout_, - arg.p_indices_, - static_cast(arg.p_workspace_), - PassThrough{}); - - elapsed_time += launch_and_time_kernel( - stream_config, - cast_kernel, - dim3(gridSize), - dim3(arg.blockSize_), - 0, - ck::make_tuple(din_grid_desc), - ck::make_tuple(din_grid_desc), - static_cast(arg.p_workspace_), - arg.p_din_, - UnaryConvert{}); - - return elapsed_time; - } - else - { - const auto put_kernel = kernel_put_element_1d; - - hip_check_error(hipMemsetAsync(arg.p_din_, - 0, - arg.din_length_raw_ * sizeof(DInDataType), - stream_config.stream_id_)); - - return launch_and_time_kernel(stream_config, - put_kernel, - dim3(gridSize), - dim3(arg.blockSize_), - 0, - dout_grid_desc, - arg.p_dout_, - arg.p_indices_, - arg.p_din_, - PassThrough{}); - } - } - } - - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - size_t GetWorkSpaceSize(const BaseArgument* pArg) const override - { - const Argument* pArg_ = dynamic_cast(pArg); - - bool needCast = pArg_->windowOverlap_ && - !(is_same_v || is_same_v); - - if(!needCast) - return 0; - else - return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast); - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* pArg = dynamic_cast(p_arg); - if(pArg->din_length_raw_ % InOutVectorSize != 0 || - pArg->dout_length_raw_ % InOutVectorSize != 0) - { - return false; - } - return true; - } - - std::unique_ptr - MakeArgumentPointer(const 
void* p_dout, - const void* p_indices, - void* p_din, - index_t dout_length, - index_t din_length, - std::vector window_lengths, - std::vector window_strides) override - { - // Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are - // physical size of the packed tensor - return std::make_unique(static_cast(p_dout), - static_cast(p_indices), - static_cast(p_din), - dout_length, - din_length, - window_lengths, - window_strides); - } - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f86b181e7e12f13c1616e978a8f5d06def0bb0ce --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp @@ -0,0 +1,366 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd +{ + using DInDataType_AutomicAddPreCast = + conditional_t || is_same_v, + DInDataType, + float>; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template + static auto PadDescriptor_M_1d(Desc_M& desc_m, index_t loop_step) + { + const auto m = desc_m.GetLength(I0); + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(index_t length, index_t loop_step) + { + const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); + return PadDescriptor_M_1d(desc_m, loop_step); + } + + template + static auto ExpendDescFirstDim(Desc_M desc_m) + { + return transform_tensor_descriptor( + desc_m, + make_tuple(make_unmerge_transform(make_tuple(I1, desc_m.GetLength(I0)))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + + using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1)); + using InOutGrid2dDesc = decltype(ExpendDescFirstDim(InOutGrid1dDesc{})); + + using GridwisePutElementSet = GridwisePutElement_1D; + + using GridwisePutElementAtomicAdd = GridwisePutElement_1D; + + static constexpr index_t BlockSize = 256; + static constexpr index_t MPerThread = 1; + static constexpr index_t NPerThread = InOutVectorSize; + static constexpr index_t 
MPerBlock = 1; + static constexpr index_t NPerBlock = BlockSize * NPerThread; + + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseCasting = GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + UnaryConvert, + BlockSize, + MPerBlock, + NPerBlock, + MPerThread, + NPerThread, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + const IndexDataType* p_indices, + DInDataType* p_din, + index_t dout_length, + index_t din_length, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations) + : p_dout_{p_dout}, + p_indices_{p_indices}, + p_din_{p_din}, + dout_length_raw_{dout_length}, + din_length_raw_{din_length}, + blockSize_{BlockSize}, + windowOverlap_{false} + { + for(size_t i = 0; i < window_lengths.size(); ++i) + { + auto eff = (window_lengths.at(i) - 1) * window_dilations.at(i) + 1; + windowOverlap_ |= eff > window_strides.at(i); + } + } + + const DOutDataType* p_dout_; + const IndexDataType* p_indices_; + DInDataType* p_din_; + index_t dout_length_raw_; + index_t din_length_raw_; + index_t blockSize_; + bool windowOverlap_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize; + InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step); + InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step); + + if constexpr(is_same_v || is_same_v) + { + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + if(arg.windowOverlap_) + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + else + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + else + { + if(arg.windowOverlap_) + { + if(arg.p_workspace_ == nullptr) + throw std::runtime_error("wrong! 
WorkSpace pointer has not been set"); + + hip_check_error( + hipMemsetAsync(arg.p_workspace_, + 0, + arg.din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast), + stream_config.stream_id_)); + + const auto put_kernel = kernel_put_element_1d; + + const auto cast_kernel = + kernel_elementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + UnaryConvert>; + + float elapsed_time = launch_and_time_kernel( + stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + static_cast(arg.p_workspace_), + PassThrough{}); + + InOutGrid2dDesc din_grid_desc_2d = ExpendDescFirstDim(din_grid_desc); + const index_t M = din_grid_desc_2d.GetLength(I0); + const index_t N = din_grid_desc_2d.GetLength(I1); + const auto block_2_tile_map = Block2TileMap(M, N); + const auto cast_kernel_grid_size = + block_2_tile_map.CalculateGridSize(din_grid_desc_2d); + + elapsed_time += launch_and_time_kernel( + stream_config, + cast_kernel, + dim3(cast_kernel_grid_size), + dim3(arg.blockSize_), + 0, + ck::make_tuple(din_grid_desc_2d), + ck::make_tuple(din_grid_desc_2d), + ck::make_tuple( + static_cast(arg.p_workspace_)), + ck::make_tuple(arg.p_din_), + block_2_tile_map, + UnaryConvert{}); + + return elapsed_time; + } + else + { + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + const auto put_kernel = kernel_put_element_1d; + + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + bool needCast = pArg_->windowOverlap_ && + !(is_same_v || is_same_v); + + if(!needCast) + return 0; + else + return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast); + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + if(pArg->din_length_raw_ % InOutVectorSize != 0 || + pArg->dout_length_raw_ % InOutVectorSize != 0) + { + return false; + } + return true; + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations) override + { + // Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are + // physical size of the packed tensor + return std::make_unique(static_cast(p_dout), + static_cast(p_indices), + static_cast(p_din), + dout_length, + din_length, + window_lengths, + window_strides, + window_dilations); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..cc88c1a10473a86dc7fd0bdc54c68f25c17a9682 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -0,0 +1,1245 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Multi-Query Attention (MQA) kernel implementation +// Assume number of head of K,V is 1. +// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N] +// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O] +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** + const auto q_head = G1; + const auto kv_head = 1; +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, q_head, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K] + : std::array{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, kv_head, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K] + : std::array{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, 1, N, K] + + std::array b1_gs_os_ns_lengths{G0, kv_head, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O] + : std::array{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, 1, N, O] + + std::array c_gs_ms_os_lengths{G0, q_head, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? 
std::array{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O] + : std::array{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx / G1))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx / G1))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = alpha; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// 
Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceMultiQueryAttentionForward_Wmma + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceMultiQueryAttentionForward_Wmma; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ 
__device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + 
ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? 
std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + 
a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_multi_query_attention_wmma; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 
Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + 
b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceMultiQueryAttentionForward_Wmma" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86689af0b7021a8946717b5e7a1ac23559783aa2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp @@ -0,0 +1,465 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
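+
+// DeviceNormalizationBwdDataImpl computes the input gradient DX from DY, X, Gamma,
+// Mean and InvStd. The Rank-dimensional problem is flattened to a 2D [M, K] view in
+// which M merges the invariant dimensions and K merges the reduced dimensions
+// (see Make2dDescriptor below).
+//
+// Reviewer sketch (orientation only, not part of this change): for a layer-norm style
+// normalization y = gamma * (x - mean) * inv_std + beta, the data gradient is commonly
+//     dy_hat = dy * gamma
+//     dx     = inv_std * (dy_hat - mean_K(dy_hat) - x_hat * mean_K(dy_hat * x_hat))
+// with x_hat = (x - mean) * inv_std and mean_K(.) averaging over the reduced axis; the
+// gridwise kernel used below is assumed to implement an equivalent reduction.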
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// M is Invariant dimension, K is reduced dimension +namespace ck { +namespace tensor_operation { +namespace device { +template +__global__ void +kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k, + const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K mean_grid_desc_m_k, + const GridDesc_M_K inv_std_grid_desc_m_k, + const GridDesc_M_K dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) +{ + GridwiseNormalizationBwd::Run(dy_grid_desc_m_k, + x_grid_desc_m_k, + gamma_grid_desc_m_k, + mean_grid_desc_m_k, + inv_std_grid_desc_m_k, + dx_grid_desc_m_k, + num_k_block_tile_iteration, + p_dy_global, + p_x_global, + p_gamma_global, + p_mean_global, + p_inv_std_global, + p_dx_global); +}; + +template +struct DeviceNormalizationBwdDataImpl : public DeviceNormalizationBwdData +{ + static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; + static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; + static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0; + static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; + static constexpr index_t DXDstVectorDim = IsDxFastestDimReduced ? 
1 : 0; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) || + (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), + "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " + "check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) || + (DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); + + static auto Make2dDescriptor(const std::vector& lengths, + const std::vector& strides, + int numBlockTileIteration) + { + const auto tupleLengths = make_tuple_from_array(lengths, Number{}); + const auto tupleStrides = make_tuple_from_array(strides, Number{}); + + const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(lengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + + return transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, pad_M), + make_right_pad_transform(reduceLength, pad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k_padded; + } + + using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1)); + + using GridwiseNormalizationBwdDataGeneric = + GridwiseNormalizationBwdData_mk_to_mk; + + using GridwiseNormalizationBwdDataSweepOnce = + GridwiseNormalizationBwdData_mk_to_mk; + + struct 
Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const DYDataType* p_dy, + const XDataType* p_x, + const GammaDataType* p_gamma, + const MeanInvStdDataType* p_mean, + const MeanInvStdDataType* p_invStd, + DXDataType* p_dx) + : p_dy_(p_dy), + p_x_(p_x), + p_gamma_(p_gamma), + p_mean_(p_mean), + p_invStd_(p_invStd), + p_dx_(p_dx) + { + lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + dyStrides_ = shuffle_tensor_dimensions(dyStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + meanStrides_ = shuffle_tensor_dimensions(meanStrides, reduceDims); + invStdStrides_ = + shuffle_tensor_dimensions(invStdStrides, reduceDims); + dxStrides_ = shuffle_tensor_dimensions(dxStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(lengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + dy_grid_desc_m_k_ = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_); + x_grid_desc_m_k_ = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_); + mean_grid_desc_m_k_ = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_); + inv_std_grid_desc_m_k_ = + Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_); + dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_); + + isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize; + } + + const DYDataType* p_dy_; + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const MeanInvStdDataType* p_mean_; + const MeanInvStdDataType* p_invStd_; + DXDataType* p_dx_; + + std::vector lengths_; + std::vector dyStrides_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector meanStrides_; + std::vector invStdStrides_; + std::vector dxStrides_; + + int numBlockTileIteration_; + size_t gridSize_; + + // tensor descriptor + GridDesc_M_K dy_grid_desc_m_k_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K mean_grid_desc_m_k_; + GridDesc_M_K inv_std_grid_desc_m_k_; + GridDesc_M_K dx_grid_desc_m_k_; + + bool isSweeponce_; + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + auto KernelSelector(bool isSweepOnce) + { + return isSweepOnce + ? 
kernel_normalization_bwd_data + : kernel_normalization_bwd_data; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel_main = KernelSelector(arg.isSweeponce_); + + return launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.dy_grid_desc_m_k_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.mean_grid_desc_m_k_, + arg.inv_std_grid_desc_m_k_, + arg.dx_grid_desc_m_k_, + arg.numBlockTileIteration_, + arg.p_dy_, + arg.p_x_, + arg.p_gamma_, + arg.p_mean_, + arg.p_invStd_, + arg.p_dx_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + template + bool IsVectorDimSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(SrcVectorSize == 1) + return true; + + // Fastest dimension is not reduced + if constexpr(SrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + return false; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0) + return false; + } + else // Fastest dimension is reduced + { + if(strides[Rank - 1] != 1) + return false; + + if(lengths[Rank - 1] % SrcVectorSize != 0) + return false; + }; + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + bool pass = true; + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dyStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->xStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->gammaStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->meanStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->invStdStrides_); + + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dxStrides_); + return pass; + } + + std::unique_ptr MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) override + { + if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank || + gammaStrides.size() != Rank || meanStrides.size() != Rank || + invStdStrides.size() != Rank || dxStrides.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + dyStrides, + xStrides, + gammaStrides, + meanStrides, + invStdStrides, + dxStrides, + reduceDims, + static_cast(p_dy), + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_mean), + static_cast(p_invStd), + static_cast(p_dx)); + } + + virtual std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationBwdDataImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "DYSrcVectorSize" << DYSrcVectorSize << "_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_MeanRstd" << MeanInvStdSrcVectorSize << "_Dx" << DXDstVectorSize; + str 
<< ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c35652bfe860aeb56f4d5621aed2a312c78282d6 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp @@ -0,0 +1,478 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// M is Invariant dimension, K is reduced dimension +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +kernel_normalization_bwd_gamma_beta(const GridDesc_M_K dy_grid_desc_m_k, + const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K mean_grid_desc_m_k, + const GridDesc_M_K inv_std_grid_desc_m_k, + const GridDesc_M dgamma_grid_desc_m, + const GridDesc_M dbeta_grid_desc_m, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DGammaDataType* const __restrict__ p_dgamma_global, + DBetaDataType* const __restrict__ p_dbeta_global) +{ + GridwiseReduction::Run(dy_grid_desc_m_k, + x_grid_desc_m_k, + mean_grid_desc_m_k, + inv_std_grid_desc_m_k, + dgamma_grid_desc_m, + dbeta_grid_desc_m, + num_k_block_tile_iteration, + p_dy_global, + p_x_global, + p_mean_global, + p_inv_std_global, + p_dgamma_global, + p_dbeta_global); +}; + +template +struct DeviceNormalizationBwdGammaBetaImpl + : public DeviceNormalizationBwdGammaBeta +{ + static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; + static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; + static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 
1 : 0; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) || + (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), + "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " + "check!"); + + static_assert( + ((MThreadSliceSize % DGammaDstVectorSize == 0) || + (MThreadSliceSize % DBetaDstVectorSize == 0)), + "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please " + "check!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int numBlockTileIteration) + { + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor(inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_grid_desc_m_k_padded; + } + + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + 
make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); + using GridDesc_M = decltype(MakeDst1dDescriptor({1}, {1})); + + using GridwiseNormalizationBwdGammaBeta = + GridwiseNormalizationBwdGammaBeta_mk_to_k; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const DYDataType* p_dy, + const XDataType* p_x, + const MeanInvStdDataType* p_mean, + const MeanInvStdDataType* p_invStd, + DGammaDataType* p_dgamma, + DBetaDataType* p_dbeta) + : p_dy_(p_dy), + p_x_(p_x), + p_mean_(p_mean), + p_invStd_(p_invStd), + p_dgamma_(p_dgamma), + p_dbeta_(p_dbeta), + outLengths_{outLengths}, + dgammaStrides_{dgammaStrides}, + dbetaStrides_{dbetaStrides} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + dyStrides_ = shuffle_tensor_dimensions(dyStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + meanStrides_ = shuffle_tensor_dimensions(meanStrides, reduceDims); + invStdStrides_ = + shuffle_tensor_dimensions(invStdStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(inLengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + dy_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, dyStrides_, numBlockTileIteration_); + x_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, xStrides_, numBlockTileIteration_); + mean_grid_desc_m_k_ = + MakeSrc2dDescriptor(inLengths_, meanStrides_, numBlockTileIteration_); + inv_std_grid_desc_m_k_ = + MakeSrc2dDescriptor(inLengths_, invStdStrides_, numBlockTileIteration_); + + dgamma_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dgammaStrides_); + dbeta_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dbetaStrides_); + } + + const DYDataType* p_dy_; + const XDataType* p_x_; + const MeanInvStdDataType* p_mean_; + const MeanInvStdDataType* p_invStd_; + DGammaDataType* p_dgamma_; + DBetaDataType* p_dbeta_; + + std::vector inLengths_; + std::vector dyStrides_; + std::vector xStrides_; + std::vector meanStrides_; + std::vector invStdStrides_; + std::vector outLengths_; + std::vector dgammaStrides_; + std::vector dbetaStrides_; + + int numBlockTileIteration_; + size_t gridSize_; + + // Source descriptor + GridDesc_M_K dy_grid_desc_m_k_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K mean_grid_desc_m_k_; + GridDesc_M_K inv_std_grid_desc_m_k_; + + // Destination descriptor + GridDesc_M dgamma_grid_desc_m_; + GridDesc_M dbeta_grid_desc_m_; + + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { 
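+            // One grid of gridSize_ = ceil(MRaw_ / M_BlockTileSize) blocks; each block
+            // iterates over numBlockTileIteration_ K-tiles of dy/x/mean/inv_std and
+            // reduces them into a single dgamma and dbeta value per M element
+            // (reviewer note: typically dgamma = sum_K(dy * x_hat), dbeta = sum_K(dy);
+            // the exact reduction lives in the gridwise kernel, not in this file).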
+ const auto kernel_main = + kernel_normalization_bwd_gamma_beta; + + return launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.dy_grid_desc_m_k_, + arg.x_grid_desc_m_k_, + arg.mean_grid_desc_m_k_, + arg.inv_std_grid_desc_m_k_, + arg.dgamma_grid_desc_m_, + arg.dbeta_grid_desc_m_, + arg.numBlockTileIteration_, + arg.p_dy_, + arg.p_x_, + arg.p_mean_, + arg.p_invStd_, + arg.p_dgamma_, + arg.p_dbeta_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + template + bool IsSrcVectorDimSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(SrcVectorSize == 1) + return true; + + // Fastest dimension is not reduced + if constexpr(SrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + return false; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0) + return false; + } + else // Fastest dimension is reduced + { + if(strides[Rank - 1] != 1) + return false; + + if(lengths[Rank - 1] % SrcVectorSize != 0) + return false; + }; + + return true; + } + + template + bool IsDstVectorSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(DstVectorSize == 1) + return true; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % DstVectorSize != 0) + return false; + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + bool pass = true; + pass &= IsSrcVectorDimSizeValid(p_arg_->inLengths_, + p_arg_->dyStrides_); + pass &= IsSrcVectorDimSizeValid(p_arg_->inLengths_, + p_arg_->xStrides_); + pass &= IsSrcVectorDimSizeValid( + p_arg_->inLengths_, p_arg_->meanStrides_); + pass &= IsSrcVectorDimSizeValid( + p_arg_->inLengths_, p_arg_->invStdStrides_); + + pass &= + IsDstVectorSizeValid(p_arg_->outLengths_, p_arg_->dgammaStrides_); + pass &= + IsDstVectorSizeValid(p_arg_->outLengths_, p_arg_->dbetaStrides_); + + return pass; + } + + std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_mean, + const void* p_invStd, + void* p_dgamma, + void* p_dbeta) override + { + if(inLengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank || + meanStrides.size() != Rank || invStdStrides.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + if(outLengths.size() != NumInvariantDim || dgammaStrides.size() != NumInvariantDim || + dbetaStrides.size() != NumInvariantDim) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(inLengths, + dyStrides, + xStrides, + meanStrides, + invStdStrides, + outLengths, + dgammaStrides, + dbetaStrides, + reduceDims, + static_cast(p_dy), + static_cast(p_x), + static_cast(p_mean), + static_cast(p_invStd), + static_cast(p_dgamma), + static_cast(p_dbeta)); + } + + virtual std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << 
"DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize ; + str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fa7b9cbda02334449d2af14561262cdf51f69356 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp @@ -0,0 +1,475 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = Normalization(X, Beta, Gamma) +// M: Invariant length +// K: Reduce length (Calculate mean and variance along K dimension) +// eg. Length = [N, C, H, W], reduce dim = [C, H, W] +// Then, M = N, K = C * H * W +template +struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd +{ + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || + (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); // TODO + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int numBlockTileIteration) + { + static constexpr index_t numSrcDim = Rank; + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = 
transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeSaveMeanInvStdDescriptor_M(const std::vector& lengths, + const std::vector& strides) + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + + const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{}); + + const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto grid_desc_m = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(InvariantDims{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = transform_tensor_descriptor( + grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, pad_M)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return grid_desc_m_padded; + } + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); + using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})); + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + YElementwiseOperation y_elementwise_op, + double epsilon, + const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + SaveMeanInvStdDataType* p_saveMean, + SaveMeanInvStdDataType* p_saveInvStd) + : p_x_(p_x), + p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + p_saveMean_(p_saveMean), + 
p_saveInvStd_(p_saveInvStd), + y_elementwise_op_(y_elementwise_op) + { + epsilon_ = static_cast(epsilon); + + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + saveMeanStrides_ = saveMeanStrides; + saveInvStdStrides_ = saveInvStdStrides; + + std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, gammaStrides_, numBlockTileIteration_); + beta_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_); + y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_); + save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides); + save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides); + + isSweeponce_ = + x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = Lengths_[NumInvariantDim - 1]; + } + + ComputeDataType epsilon_; + + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + SaveMeanInvStdDataType* p_saveMean_; + SaveMeanInvStdDataType* p_saveInvStd_; + + std::vector Lengths_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + std::vector saveMeanStrides_; + std::vector saveInvStdStrides_; + + YElementwiseOperation y_elementwise_op_; + + int numBlockTileIteration_; + size_t gridSize_; + + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K beta_grid_desc_m_k_; + GridDesc_M_K y_grid_desc_m_k_; + GridDesc_M save_mean_grid_desc_m_; + GridDesc_M save_inv_std_grid_desc_m_; + bool isSweeponce_; + + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + + index_t invariant_lowest_length_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto kernel_main = NormalizationKernelSelector(arg.isSweeponce_); + + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.save_mean_grid_desc_m_, + arg.save_inv_std_grid_desc_m_, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.p_saveMean_, + arg.p_saveInvStd_, + arg.y_elementwise_op_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + if constexpr(XYSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) + return false; + + 
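+                // Vectorized access along the invariant (M) axis also requires the raw
+                // length of the fastest invariant dimension to be divisible by both the
+                // X read and Y write vector widths, checked below.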
if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0) + return false; + + if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0) + return false; + }; + } + else + { + if(p_arg_->xStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + { + return false; + } + }; + + // if fastest dim is not reduced + if constexpr(GammaSrcVectorDim == 0) + { + if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->gammaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + + // if fastest dim is not reduced + if constexpr(BetaSrcVectorDim == 0) + { + if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->betaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) + return (false); + } + + if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0) + return false; + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + double epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_saveMean, + void* p_saveInvStd, + YElementwiseOperation y_elementwise_op) override + { + if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank || + betaStrides.size() != Rank || yStrides.size() != Rank || + saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + xStrides, + gammaStrides, + betaStrides, + yStrides, + saveMeanStrides, + saveInvStdStrides, + reduceDims, + y_elementwise_op, + epsilon, + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y), + static_cast(p_saveMean), + static_cast(p_saveInvStd)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationFwdImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7fe6502bab454ae38359873d8c83992a2f16dd8b --- /dev/null +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp @@ -0,0 +1,750 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void +kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, + const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x_global, + WorkspaceMeanVarDataType* const __restrict__ p_welford_mean, + WorkspaceMeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count) +{ + GridwiseWelford::Run(x_grid_desc_m_k, + mean_var_grid_desc_m_kblock, + num_k_block_tile_iteration, + p_x_global, + p_welford_mean, + p_welford_variance, + p_welford_count); +}; + +template +__global__ void +kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, + const CountGridDesc_M_KBlock count_grid_desc_m_kblock, + const XYGammaBetaGridDesc_M_K x_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, + const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m, + const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m, + index_t num_k_mean_var_count_iteration, + index_t num_k_block_tile_iteration, + index_t k_grid_size, + ComputeDataType epsilon, + const WorkspaceMeanVarDataType* const p_mean_global, + const WorkspaceMeanVarDataType* const p_variance_global, + const int32_t* const p_welford_count_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) +{ + GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock, + count_grid_desc_m_kblock, + x_grid_desc_m_k, + gamma_grid_desc_m_k, + beta_grid_desc_m_k, + y_grid_desc_m_k, + save_mean_grid_desc_m, + save_inv_std_grid_desc_m, + num_k_mean_var_count_iteration, + num_k_block_tile_iteration, + k_grid_size, + epsilon, + p_mean_global, + p_variance_global, + p_welford_count_global, + p_x_global, + p_gamma_global, + p_beta_global, + p_y_global, + p_save_mean_global, + p_save_inv_std_global, + y_elementwise_op); +}; +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = Normalization(X, Beta, Gamma) +// M: Invariant length +// K: Reduce length (Calculate mean and variance along K dimension) +// eg. 
Length = [N, C, H, W], reduce dim = [C, H, W] +// Then, M = N, K = C * H * W +template +struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd +{ + using WorkspaceMeanVarDataType = SaveMeanInvStdDataType; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || + (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); // TODO + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int kBlockSize, + int numBlockTileIteration) + { + static constexpr index_t numSrcDim = Rank; + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * kBlockSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + 
make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + template + static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K) + { + const auto grid_desc_m_k = + make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); + } + + template + static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K) + { + const auto grid_desc_m_k = + make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1)); + return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); + } + + static auto MakeSaveMeanInvStdDescriptor_M(const std::vector& lengths, + const std::vector& strides) + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + + const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{}); + + const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto grid_desc_m = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(InvariantDims{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = transform_tensor_descriptor( + grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, pad_M)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return grid_desc_m_padded; + } + + using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + using Kernel1MeanVarGridDesc_M_KBlock = + decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1)); + + using Kernel2MeanVarGridDesc_M_KBlock = + decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1)); + + using Kernel2CountGridDesc_M_KBlock = + decltype(MakeWorkspaceCountDescriptor_M_K, 1, 1>(1, 1)); + + using SaveMeanInvStdGridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})); + + using GridwiseWelford = GridwiseNormalizationSplitK1st; + + using GridwiseWelfordNormalization = + GridwiseNormalizationSplitK2nd; + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + YElementwiseOperation y_elementwise_op, + double epsilon, + const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + SaveMeanInvStdDataType* p_saveMean, + SaveMeanInvStdDataType* p_saveInvStd) + : p_x_(p_x), + p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + p_saveMean_(p_saveMean), + p_saveInvStd_(p_saveInvStd), + p_workspace_mean_{nullptr}, + p_workspace_var_{nullptr}, + p_workspace_count_{nullptr}, + y_elementwise_op_(y_elementwise_op) + { + epsilon_ = static_cast(epsilon); + + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + gammaStrides_ = 
shuffle_tensor_dimensions(gammaStrides, reduceDims); + betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + saveMeanStrides_ = saveMeanStrides; + saveInvStdStrides_ = saveInvStdStrides; + + std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); + + numBlockTileIteration_ = 1; + while(true) + { + int testKGridSize = + math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); + + // we want the kGridSize_ be not more than 128 + if(testKGridSize <= 128) + break; + + ++numBlockTileIteration_; + }; + + kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_; + + // We do not use vector load for mean, var and count + static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize; + + numMeanVarCountIteration_ = + math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize); + + x_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, gammaStrides_, kGridSize_, numBlockTileIteration_); + beta_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, betaStrides_, kGridSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_); + + save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides); + save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides); + + // We don't need to pad in K dimension for Welford1. Set KPerTile 1. + kernel1_mean_var_grid_desc_m_kblock_ = + MakeWorkspaceMeanVarDescriptor_M_K, M_BlockTileSize, 1>( + MRaw_, kGridSize_); + + kernel2_mean_var_grid_desc_m_kblock_ = + MakeWorkspaceMeanVarDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + + kernel2_count_grid_desc_m_kblock_ = + MakeWorkspaceCountDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = Lengths_[NumInvariantDim - 1]; + } + + ComputeDataType epsilon_; + + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + SaveMeanInvStdDataType* p_saveMean_; + SaveMeanInvStdDataType* p_saveInvStd_; + void* p_workspace_mean_; + void* p_workspace_var_; + void* p_workspace_count_; + + std::vector Lengths_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + std::vector saveMeanStrides_; + std::vector saveInvStdStrides_; + + YElementwiseOperation y_elementwise_op_; + + int kGridSize_; + int numMeanVarCountIteration_; + int numBlockTileIteration_; + size_t gridSize_; + + SrcGridDesc_M_K x_grid_desc_m_k_; + SrcGridDesc_M_K gamma_grid_desc_m_k_; + SrcGridDesc_M_K beta_grid_desc_m_k_; + SrcGridDesc_M_K y_grid_desc_m_k_; + SaveMeanInvStdGridDesc_M save_mean_grid_desc_m_; + SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m_; + + Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_; + Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_; + Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_; + + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + + index_t invariant_lowest_length_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + 
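// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch] The Argument
// constructor above chooses numBlockTileIteration_ so that no more than 128
// blocks are launched along the K (reduce) dimension, then derives kGridSize_
// and gridSize_ from it. The standalone function below reproduces only that
// arithmetic; integer_divide_ceil is re-implemented locally and all names are
// local to this sketch.
// ---------------------------------------------------------------------------
#include <cstdio>

namespace splitk_sizing_sketch {

inline int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

inline void show_split_k_sizing(int MRaw, int KRaw, int M_BlockTileSize, int K_BlockTileSize)
{
    // Grow the per-block K work until at most 128 K-blocks are needed.
    int numBlockTileIteration = 1;
    while(integer_divide_ceil(KRaw, K_BlockTileSize * numBlockTileIteration) > 128)
        ++numBlockTileIteration;

    const int kGridSize = integer_divide_ceil(KRaw, K_BlockTileSize * numBlockTileIteration);
    const int gridSize  = integer_divide_ceil(MRaw, M_BlockTileSize) * kGridSize;

    std::printf("numBlockTileIteration=%d kGridSize=%d gridSize=%d\n",
                numBlockTileIteration, kGridSize, gridSize);
}

} // namespace splitk_sizing_sketch
// Example: MRaw = 1024, KRaw = 65536, M_BlockTileSize = 128, K_BlockTileSize = 256
// gives numBlockTileIteration = 2, kGridSize = 128, gridSize = 1024.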
if(arg.p_workspace_mean_ == nullptr || arg.p_workspace_var_ == nullptr || + arg.p_workspace_count_ == nullptr) + throw std::runtime_error("wrong! WorkSpace pointer has not been set"); + + auto kernel1 = kernel_normalizationSplitK1st; + + auto kernel2 = kernel_normalizationSplitK2nd; + + float avg_time = 0; + avg_time += launch_and_time_kernel( + stream_config, + kernel1, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.kernel1_mean_var_grid_desc_m_kblock_, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_)); + + avg_time += launch_and_time_kernel( + stream_config, + kernel2, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.kernel2_mean_var_grid_desc_m_kblock_, + arg.kernel2_count_grid_desc_m_kblock_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.save_mean_grid_desc_m_, + arg.save_inv_std_grid_desc_m_, + arg.numMeanVarCountIteration_, + arg.numBlockTileIteration_, + arg.kGridSize_, + arg.epsilon_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.p_x_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.p_saveMean_, + arg.p_saveInvStd_, + arg.y_elementwise_op_); + + return avg_time; + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; + + // workspace for welford intermediate mean + workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64; + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; + + // setup buffer used for intermediate welford mean + pArg_->p_workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType); + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType); + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welford count + pArg_->p_workspace_count_ = + reinterpret_cast(pArg_->p_workspace_var_) + variance_space_sz; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + if constexpr(XYVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0) + return false; + + if(p_arg_->invariant_lowest_length_ % YDstVectorSize 
!= 0) + return false; + }; + } + else + { + if(p_arg_->xStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + // if fastest dim is not reduced + if constexpr(GammaSrcVectorDim == 0) + { + if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return false; + } + else // if fastest dim is reduced + { + if(p_arg_->gammaStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return false; + } + + // if fastest dim is not reduced + if constexpr(BetaSrcVectorDim == 0) + { + if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0) + return false; + } + else // if fastest dim is reduced + { + if(p_arg_->betaStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) + return false; + } + + if(p_arg_->kGridSize_ <= 1) + return false; + + if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0) + return false; + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + double epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_saveMean, + void* p_saveInvStd, + YElementwiseOperation y_elementwise_op) override + { + if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank || + betaStrides.size() != Rank || yStrides.size() != Rank || + saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + xStrides, + gammaStrides, + betaStrides, + yStrides, + saveMeanStrides, + saveInvStdStrides, + reduceDims, + y_elementwise_op, + epsilon, + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y), + static_cast(p_saveMean), + static_cast(p_saveInvStd)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationFwdSplitKImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp deleted file mode 100644 index ea0d805043e07fe3d818fbb7bf7ef42e952496a9..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp +++ /dev/null @@ -1,402 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced 
Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/reduction_operator.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" -#include "ck/tensor_operation/gpu/device/device_reduce.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" -#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// Y = Normalization(X, Beta, Gamma) -// M: Invarient length -// K: Reduce length (Calculate mean and variance along K dimension) -// eg. Length = [N, C, H, W], reduce dim = [C, H, W] -// Then, M = N, K = C * H * W -template -struct DeviceNormalizationImpl : public DeviceNormalization -{ - static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); - static_assert( - ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || - (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), - "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); - - static_assert( - ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || - (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), - "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); - - using PassThrough = tensor_operation::element_wise::PassThrough; - - static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - - static auto MakeSrc2dDescriptor(const std::vector& inLengths, - const std::vector& inStrides, - int numBlockTileIteration) - { - constexpr index_t NumInvariantDim = Rank - NumReduceDim; - static constexpr index_t numSrcDim = Rank; - static constexpr bool reduceAllDim = (NumInvariantDim == 0); - - const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); - - const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - - const auto in_grid_desc_m_k = [&]() { - if constexpr(reduceAllDim) - { - const auto one_dim_inDesc = transform_tensor_descriptor( - inDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), - make_tuple(Sequence<0>{})); - - return transform_tensor_descriptor(one_dim_inDesc, - make_tuple(make_unmerge_transform(make_tuple( - 1, one_dim_inDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - } - else - { - using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; - using ReduceDims = typename arithmetic_sequence_gen::type; - - const auto reduceDimLengths = - make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); - - return transform_tensor_descriptor( - inDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(reduceDimLengths)), - make_tuple(InvariantDims{}, ReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - }(); - - const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); - const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - - const auto inPad_M = - 
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; - - auto in_grid_desc_m_k_padded = transform_tensor_descriptor( - in_grid_desc_m_k, - make_tuple(make_right_pad_transform(invariantLength, inPad_M), - make_right_pad_transform(reduceLength, inPad_K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return (in_grid_desc_m_k_padded); - }; - - using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); - - struct Argument : public BaseArgument - { - Argument(const std::vector lengths, - const std::vector xStrides, - const std::vector gammaStrides, - const std::vector betaStrides, - const std::vector yStrides, - const std::vector reduceDims, - YElementwiseOperation y_elementwise_op, - double epsilon, - const XDataType* p_x, - const GammaDataType* p_gamma, - const BetaDataType* p_beta, - YDataType* p_y) - : p_x_(p_x), - p_gamma_(p_gamma), - p_beta_(p_beta), - p_y_(p_y), - y_elementwise_op_(y_elementwise_op) - { - epsilon_ = static_cast(epsilon); - - Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); - xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); - yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); - gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); - betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); - - long_index_t invariant_length; - long_index_t reduce_length; - - std::tie(invariant_length, reduce_length) = - get_2d_lengths(Lengths_); - - numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize); - - gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize); - - x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_); - gamma_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, gammaStrides_, numBlockTileIteration_); - beta_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_); - y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_); - - isSweeponce_ = - x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; - } - - ComputeDataType epsilon_; - - const XDataType* p_x_; - const GammaDataType* p_gamma_; - const BetaDataType* p_beta_; - YDataType* p_y_; - - std::vector Lengths_; - std::vector xStrides_; - std::vector gammaStrides_; - std::vector betaStrides_; - std::vector yStrides_; - - YElementwiseOperation y_elementwise_op_; - - int numBlockTileIteration_; - size_t gridSize_; - - GridDesc_M_K x_grid_desc_m_k_; - GridDesc_M_K gamma_grid_desc_m_k_; - GridDesc_M_K beta_grid_desc_m_k_; - GridDesc_M_K y_grid_desc_m_k_; - bool isSweeponce_; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - auto kernel_main = NormalizationKernelSelector(arg.isSweeponce_); - - float avg_time = 0; - avg_time += launch_and_time_kernel(stream_config, - kernel_main, - dim3(arg.gridSize_), - dim3(BlockSize), - 0, - arg.x_grid_desc_m_k_, - arg.gamma_grid_desc_m_k_, - arg.beta_grid_desc_m_k_, - arg.y_grid_desc_m_k_, - arg.numBlockTileIteration_, - arg.epsilon_, - arg.p_x_, - arg.p_gamma_, - arg.p_beta_, - arg.p_y_, - arg.y_elementwise_op_); - - return (avg_time); - }; - - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - }; - }; - - bool 
IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* p_arg_ = dynamic_cast(p_arg); - - constexpr index_t NumInvariantDim = Rank - NumReduceDim; - - if constexpr(XYSrcVectorDim == 0) - { - if constexpr(NumInvariantDim == 0) - { - return false; - } - else - { - if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) - return false; - - if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) - return false; - - if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) - return false; - }; - } - else - { - if(p_arg_->xStrides_[Rank - 1] != 1) - return false; - - if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) - return false; - - if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) - { - return false; - } - }; - - // if fastest dim is not reduced - if constexpr(GammaSrcVectorDim == 0) - { - if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) - return (false); - - if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) - return (false); - } - else // if fastest dim is reduced - { - if(p_arg_->gammaStrides_[Rank - 1] != 1) - return (false); - - if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) - return (false); - } - - // if fastest dim is not reduced - if constexpr(BetaSrcVectorDim == 0) - { - if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) - return (false); - - if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) - return (false); - } - else // if fastest dim is reduced - { - if(p_arg_->betaStrides_[Rank - 1] != 1) - return (false); - - if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) - return (false); - } - - return true; - }; - - std::unique_ptr - MakeArgumentPointer(const std::vector lengths, - const std::vector xStrides, - const std::vector gammaStrides, - const std::vector betaStrides, - const std::vector yStrides, - const std::vector reduceDims, - double epsilon, - const void* p_x, - const void* p_gamma, - const void* p_beta, - void* p_y, - void* p_saveMean, - void* p_saveInvVar, - YElementwiseOperation y_elementwise_op) override - { - // TODO - // Optional cache of the intermediate results (mean and InvVariance) during the - // forward pass could speedup in the backward - ignore = p_saveMean; - ignore = p_saveInvVar; - - return std::make_unique(lengths, - xStrides, - gammaStrides, - betaStrides, - yStrides, - reduceDims, - y_elementwise_op, - epsilon, - static_cast(p_x), - static_cast(p_gamma), - static_cast(p_beta), - static_cast(p_y)); - }; - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(); - }; - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceNormalizationImpl<" << BlockSize << ","; - str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; - str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; - str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; - str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp deleted file mode 100644 index 8b2b3c41bfdf3f4793263dd65fc5f6dc941eb13b..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp +++ 
/dev/null @@ -1,658 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/reduction_operator.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" -#include "ck/tensor_operation/gpu/device/device_reduce.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" -#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp" -#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -template -__global__ void -kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, - const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, - index_t num_k_block_tile_iteration, - const XDataType* const __restrict__ p_x_global, - MeanVarDataType* const __restrict__ p_welford_mean, - MeanVarDataType* const __restrict__ p_welford_variance, - int32_t* const __restrict__ p_welford_count) -{ - GridwiseWelford::Run(x_grid_desc_m_k, - mean_var_grid_desc_m_kblock, - num_k_block_tile_iteration, - p_x_global, - p_welford_mean, - p_welford_variance, - p_welford_count); -}; - -template -__global__ void -kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, - const CountGridDesc_M_KBlock count_grid_desc_m_kblock, - const XYGammaBetaGridDesc_M_K x_grid_desc_m_k, - const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, - const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, - const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, - index_t num_k_mean_var_count_iteration, - index_t num_k_block_tile_iteration, - index_t k_grid_size, - ComputeDataType epsilon, - const MeanVarDataType* const p_mean_global, - const MeanVarDataType* const p_variance_global, - const int32_t* const p_welford_count_global, - const XDataType* const __restrict__ p_x_global, - const GammaDataType* const __restrict__ p_gamma_global, - const BetaDataType* const __restrict__ p_beta_global, - YDataType* const __restrict__ p_y_global, - const YElementwiseOperation y_elementwise_op) -{ - GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock, - count_grid_desc_m_kblock, - x_grid_desc_m_k, - gamma_grid_desc_m_k, - beta_grid_desc_m_k, - y_grid_desc_m_k, - num_k_mean_var_count_iteration, - num_k_block_tile_iteration, - k_grid_size, - epsilon, - p_mean_global, - p_variance_global, - p_welford_count_global, - p_x_global, - p_gamma_global, - p_beta_global, - p_y_global, - y_elementwise_op); -}; -} // namespace ck - -namespace ck { -namespace tensor_operation { -namespace device { - -// Y = Normalization(X, Beta, Gamma) -// M: Invarient length -// K: Reduce length (Calculate mean and variance along K dimension) -// eg. 
Length = [N, C, H, W], reduce dim = [C, H, W] -// Then, M = N, K = C * H * W -template -struct DeviceNormalizationSplitKImpl : public DeviceNormalization -{ - using MeanVarDataType = ComputeDataType; - - static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); - static_assert( - ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || - (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), - "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); - - static_assert( - ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || - (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), - "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); - - using PassThrough = tensor_operation::element_wise::PassThrough; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - - static auto MakeSrc2dDescriptor(const std::vector& inLengths, - const std::vector& inStrides, - int kBlockSize, - int numBlockTileIteration) - { - constexpr index_t NumInvariantDim = Rank - NumReduceDim; - static constexpr index_t numSrcDim = Rank; - static constexpr bool reduceAllDim = (NumInvariantDim == 0); - - const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); - - const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - - const auto in_grid_desc_m_k = [&]() { - if constexpr(reduceAllDim) - { - const auto one_dim_inDesc = transform_tensor_descriptor( - inDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), - make_tuple(Sequence<0>{})); - - return transform_tensor_descriptor(one_dim_inDesc, - make_tuple(make_unmerge_transform(make_tuple( - 1, one_dim_inDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - } - else - { - using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; - using ReduceDims = typename arithmetic_sequence_gen::type; - - const auto reduceDimLengths = - make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); - - return transform_tensor_descriptor( - inDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(reduceDimLengths)), - make_tuple(InvariantDims{}, ReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - }(); - - const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); - const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - - const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; - const auto inPad_M = - math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - const auto inPad_K = reduceSizePerBlock * kBlockSize - reduceLength; - - auto in_grid_desc_m_k_padded = transform_tensor_descriptor( - in_grid_desc_m_k, - make_tuple(make_right_pad_transform(invariantLength, inPad_M), - make_right_pad_transform(reduceLength, inPad_K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return (in_grid_desc_m_k_padded); - }; - - template - static auto 
MakeMeanVarDescriptor_M_K(index_t M, index_t K) - { - const auto grid_desc_m_k = - make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); - return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); - } - - template - static auto MakeCountDescriptor_M_K(index_t M, index_t K) - { - const auto grid_desc_m_k = - make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1)); - return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); - } - - using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); - using Kernel1MeanVarGridDesc_M_KBlock = - decltype(MakeMeanVarDescriptor_M_K, 1, 1>(1, 1)); - - using Kernel2MeanVarGridDesc_M_KBlock = - decltype(MakeMeanVarDescriptor_M_K, 1, 1>(1, 1)); - - using Kernel2CountGridDesc_M_KBlock = - decltype(MakeCountDescriptor_M_K, 1, 1>(1, 1)); - - using GridwiseWelford = GridwiseNormalizationSplitK1st; - - using GridwiseWelfordNormalization = - GridwiseNormalizationSplitK2nd; - - struct Argument : public BaseArgument - { - Argument(const std::vector lengths, - const std::vector xStrides, - const std::vector gammaStrides, - const std::vector betaStrides, - const std::vector yStrides, - const std::vector reduceDims, - YElementwiseOperation y_elementwise_op, - double epsilon, - const XDataType* p_x, - const GammaDataType* p_gamma, - const BetaDataType* p_beta, - YDataType* p_y) - : p_x_(p_x), - p_gamma_(p_gamma), - p_beta_(p_beta), - p_y_(p_y), - p_workspace_mean_{nullptr}, - p_workspace_var_{nullptr}, - p_workspace_count_{nullptr}, - y_elementwise_op_(y_elementwise_op) - { - epsilon_ = static_cast(epsilon); - - Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); - xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); - yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); - gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); - betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); - - std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); - - numBlockTileIteration_ = 1; - while(true) - { - int testKGridSize = - math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); - - // we want the kGridSize_ be not more than 128 - if(testKGridSize <= 128) - break; - - ++numBlockTileIteration_; - }; - - kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); - gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_; - - // We do not use vector load for mean, var and count - static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize; - - numMeanVarCountIteration_ = - math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize); - - x_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_); - gamma_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, gammaStrides_, kGridSize_, numBlockTileIteration_); - beta_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, betaStrides_, kGridSize_, numBlockTileIteration_); - y_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_); - - // We don't need to pad in K dimension for Welford1. Set KPerTile 1. 
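// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch] Kernel 1 of the
// split-K path writes one partial mean/variance/count triple per K-block into
// the workspace described above, and kernel 2 combines those partials before
// normalizing. The gridwise kernel bodies are not shown in this hunk; the
// host-side function below is the standard Welford/Chan merge such a second
// stage performs, included only to make the workspace layout easier to follow.
// ---------------------------------------------------------------------------
namespace welford_merge_sketch {

struct Partial
{
    double mean;
    double variance; // population variance of the partial range
    int count;
};

// Merge two partial results (Chan et al. parallel variance update).
inline Partial merge(const Partial& a, const Partial& b)
{
    if(a.count == 0)
        return b;
    if(b.count == 0)
        return a;

    const int n        = a.count + b.count;
    const double delta = b.mean - a.mean;

    Partial r;
    r.count = n;
    r.mean  = a.mean + delta * b.count / n;
    // Convert the partial variances back to sums of squared deviations,
    // add the cross term, then renormalize by the merged count.
    r.variance = (a.variance * a.count + b.variance * b.count +
                  delta * delta * a.count * b.count / n) /
                 n;
    return r;
}

} // namespace welford_merge_sketch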
- kernel1_mean_var_grid_desc_m_kblock_ = - MakeMeanVarDescriptor_M_K, M_BlockTileSize, 1>(MRaw_, - kGridSize_); - - kernel2_mean_var_grid_desc_m_kblock_ = - MakeMeanVarDescriptor_M_K, - M_BlockTileSize, - K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); - - kernel2_count_grid_desc_m_kblock_ = - MakeCountDescriptor_M_K, - M_BlockTileSize, - K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); - } - - ComputeDataType epsilon_; - - const XDataType* p_x_; - const GammaDataType* p_gamma_; - const BetaDataType* p_beta_; - YDataType* p_y_; - void* p_workspace_mean_; - void* p_workspace_var_; - void* p_workspace_count_; - - std::vector Lengths_; - std::vector xStrides_; - std::vector gammaStrides_; - std::vector betaStrides_; - std::vector yStrides_; - - YElementwiseOperation y_elementwise_op_; - - int kGridSize_; - int numMeanVarCountIteration_; - int numBlockTileIteration_; - size_t gridSize_; - - SrcGridDesc_M_K x_grid_desc_m_k_; - SrcGridDesc_M_K gamma_grid_desc_m_k_; - SrcGridDesc_M_K beta_grid_desc_m_k_; - SrcGridDesc_M_K y_grid_desc_m_k_; - - Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_; - Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_; - Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_; - - index_t MRaw_; // invarient length - index_t KRaw_; // reduce length - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(arg.p_workspace_mean_ == nullptr || arg.p_workspace_var_ == nullptr || - arg.p_workspace_count_ == nullptr) - throw std::runtime_error("wrong! WorkSpace pointer has not been set"); - - auto kernel1 = kernel_normalizationSplitK1st; - - auto kernel2 = kernel_normalizationSplitK2nd; - - float avg_time = 0; - avg_time += launch_and_time_kernel(stream_config, - kernel1, - dim3(arg.gridSize_), - dim3(BlockSize), - 0, - arg.x_grid_desc_m_k_, - arg.kernel1_mean_var_grid_desc_m_kblock_, - arg.numBlockTileIteration_, - arg.p_x_, - static_cast(arg.p_workspace_mean_), - static_cast(arg.p_workspace_var_), - static_cast(arg.p_workspace_count_)); - - avg_time += launch_and_time_kernel(stream_config, - kernel2, - dim3(arg.gridSize_), - dim3(BlockSize), - 0, - arg.kernel2_mean_var_grid_desc_m_kblock_, - arg.kernel2_count_grid_desc_m_kblock_, - arg.x_grid_desc_m_k_, - arg.gamma_grid_desc_m_k_, - arg.beta_grid_desc_m_k_, - arg.y_grid_desc_m_k_, - arg.numMeanVarCountIteration_, - arg.numBlockTileIteration_, - arg.kGridSize_, - arg.epsilon_, - static_cast(arg.p_workspace_mean_), - static_cast(arg.p_workspace_var_), - static_cast(arg.p_workspace_count_), - arg.p_x_, - arg.p_gamma_, - arg.p_beta_, - arg.p_y_, - arg.y_elementwise_op_); - - return avg_time; - }; - - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - }; - }; - - size_t GetWorkSpaceSize(const BaseArgument* pArg) const override - { - const Argument* pArg_ = dynamic_cast(pArg); - - size_t workspace_size = 0; - - int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; - - // workspace for welford intermediate mean - workspace_size += welford_size * sizeof(MeanVarDataType) + 64; - - // workspace for welford intermediate variance - workspace_size += welford_size * sizeof(MeanVarDataType) + 64; - - // workspace for welford intermediate count - workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64; - - return (workspace_size); - }; - - void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) 
const override - { - Argument* pArg_ = dynamic_cast(pArg); - - pArg_->p_workspace_ = p_workspace; - - int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; - - // setup buffer used for intermediate welford mean - pArg_->p_workspace_mean_ = static_cast(pArg_->p_workspace_); - - index_t mean_space_sz = welford_size * sizeof(MeanVarDataType); - mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); - - // setup buffer used for intermediate welford varirance - pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz; - - index_t variance_space_sz = welford_size * sizeof(MeanVarDataType); - variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); - - // setup buffer used for intermediate welford count - pArg_->p_workspace_count_ = - reinterpret_cast(pArg_->p_workspace_var_) + variance_space_sz; - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* p_arg_ = dynamic_cast(p_arg); - - constexpr index_t NumInvariantDim = Rank - NumReduceDim; - - if constexpr(XYVectorDim == 0) - { - if constexpr(NumInvariantDim == 0) - { - return false; - } - else - { - if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) - return false; - - if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) - return false; - - if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) - return false; - }; - } - else - { - if(p_arg_->xStrides_[Rank - 1] != 1) - return false; - - if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) - return false; - - if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) - return false; - }; - - // if fastest dim is not reduced - if constexpr(GammaSrcVectorDim == 0) - { - if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) - return false; - - if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) - return false; - } - else // if fastest dim is reduced - { - if(p_arg_->gammaStrides_[Rank - 1] != 1) - return false; - - if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) - return false; - } - - // if fastest dim is not reduced - if constexpr(BetaSrcVectorDim == 0) - { - if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) - return false; - - if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) - return false; - } - else // if fastest dim is reduced - { - if(p_arg_->betaStrides_[Rank - 1] != 1) - return false; - - if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) - return false; - } - - if(p_arg_->kGridSize_ <= 1) - return false; - - return true; - }; - - std::unique_ptr - MakeArgumentPointer(const std::vector lengths, - const std::vector xStrides, - const std::vector gammaStrides, - const std::vector betaStrides, - const std::vector yStrides, - const std::vector reduceDims, - double epsilon, - const void* p_x, - const void* p_gamma, - const void* p_beta, - void* p_y, - void* p_saveMean, - void* p_saveInvVar, - YElementwiseOperation y_elementwise_op) override - { - // TODO - // Optional cache of the intermediate results (mean and InvVariance) during the - // forward pass could speedup in the backward - ignore = p_saveMean; - ignore = p_saveInvVar; - - return std::make_unique(lengths, - xStrides, - gammaStrides, - betaStrides, - yStrides, - reduceDims, - y_elementwise_op, - epsilon, - static_cast(p_x), - static_cast(p_gamma), - static_cast(p_beta), - static_cast(p_y)); - }; - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(); - }; - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << 
"DeviceNormalizationSplitKImpl<" << BlockSize << ","; - str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; - str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; - str << "XYSrcVectorDim_" << XYVectorDim << ","; - str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp index c94c568c49aca819997cd8a1d32c536d7527a551..fcb9b8a90346fd4ea62a5b7a21929fd822b0942e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp @@ -1,9 +1,20 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -16,95 +27,359 @@ template -struct DevicePool2dFwd_NHWC_NHWC : public DevicePool3dFwd_NDHWC_NDHWC +struct DevicePool2dFwd_NHWC_NHWC : public DevicePoolFwd<4, + 2, + InDataType, + OutDataType, + IndexDataType, + tensor_layout::convolution::NHWC, + tensor_layout::convolution::NHWC, + ReduceOpId, + OutputIndex> { - using DevicePool3D = DevicePool3dFwd_NDHWC_NDHWC{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t InOutRank = 4; + static constexpr index_t WindowRank = 2; + + using ReduceOperation = typename reduce_binary_operator::opType; + + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeABGridDescriptor_A_M_K_B_M(std::vector input_nchw_lengths, + std::vector output_nchw_lengths, + std::vector input_nchw_stride, + std::vector output_nchw_stride, + std::vector window_spatial_yx_lengths, + std::vector window_yx_strides, + std::vector window_yx_dilations, + std::vector input_left_hw_pads, + std::vector input_right_hw_pads) + { + const index_t N = input_nchw_lengths[0]; + const index_t C = input_nchw_lengths[1]; + const index_t Hi = input_nchw_lengths[2]; + const index_t Wi = input_nchw_lengths[3]; + + const index_t Ho = output_nchw_lengths[2]; + const index_t Wo = output_nchw_lengths[3]; + const index_t Y = window_spatial_yx_lengths[0]; + const index_t X = window_spatial_yx_lengths[1]; + + const 
index_t WindowStrideH = window_yx_strides[0]; + const index_t WindowStrideW = window_yx_strides[1]; + + const index_t WindowDilationH = window_yx_dilations[0]; + const index_t WindowDilationW = window_yx_dilations[1]; + + const index_t InLeftPadH = input_left_hw_pads[0]; + const index_t InLeftPadW = input_left_hw_pads[1]; + + const index_t InRightPadH = input_right_hw_pads[0]; + const index_t InRightPadW = input_right_hw_pads[1]; + + const index_t MRaw = N * Ho * Wo * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = Y * X; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + // A[ReduceM, ReduceK] + const index_t Ni_stride = input_nchw_stride[0]; + const index_t Ci_stride = input_nchw_stride[1]; + const index_t Hi_stride = input_nchw_stride[2]; + const index_t Wi_stride = input_nchw_stride[3]; + + const auto in_grid_desc_n_hi_wi_c = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(Ni_stride, Hi_stride, Wi_stride, Ci_stride)); + + const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_hip_wip_c, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_grid_desc_reducemraw_reducekraw = + transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)), + make_merge_transform(make_tuple(Y, X))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( + in_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B[ReduceM] + const index_t No_stride = output_nchw_stride[0]; + const index_t Co_stride = output_nchw_stride[1]; + const index_t Ho_stride = output_nchw_stride[2]; + const index_t Wo_stride = output_nchw_stride[3]; + + const auto out_grid_desc_n_ho_wo_c = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(No_stride, Ho_stride, Wo_stride, Co_stride)); + + const auto out_grid_desc_reducemraw = + transform_tensor_descriptor(out_grid_desc_n_ho_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto out_grid_desc_reducem = + transform_tensor_descriptor(out_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); + } + + using ABGridDescs = + 
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {})); + + using AGridDesc_M_K = remove_cvref_t; + using BGridDesc_M = remove_cvref_t; + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_dev, + OutDataType* p_out_dev, + IndexDataType* p_out_indices_dev, + std::vector& input_nchw_lengths, + std::vector& output_nchw_lengths, + std::vector& input_nchw_stride, + std::vector& output_nchw_stride, + std::vector&, // indices_nchw_stride + std::vector& window_spatial_yx_lengths, + std::vector& window_yx_strides, + std::vector& window_yx_dilations, + std::vector& input_left_hw_pads, + std::vector& input_right_hw_pads) + : p_in_dev_{p_in_dev}, + p_out_dev_{p_out_dev}, + p_out_indices_dev_{p_out_indices_dev}, + a_grid_desc_m_k_{}, + b_grid_desc_m_{}, + input_nchw_lengths_{input_nchw_lengths}, + output_nchw_lengths_{output_nchw_lengths}, + input_nchw_stride_{input_nchw_stride}, + output_nchw_stride_{output_nchw_stride} + { + const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_nchw_lengths, + output_nchw_lengths, + input_nchw_stride, + output_nchw_stride, + window_spatial_yx_lengths, + window_yx_strides, + window_yx_dilations, + input_left_hw_pads, + input_right_hw_pads); + + a_grid_desc_m_k_ = descs[I0]; + b_grid_desc_m_ = descs[I1]; + + int32_t reduceLength = window_spatial_yx_lengths[0] * window_spatial_yx_lengths[1]; + + std::tie(in_element_op_, acc_element_op_) = + reduce_unary_operator::GetElementwiseOperator(reduceLength); + } + + const InDataType* p_in_dev_; + OutDataType* p_out_dev_; + IndexDataType* p_out_indices_dev_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_M b_grid_desc_m_; + + InElementwiseOperation in_element_op_; + AccElementwiseOperation acc_element_op_; + + // for checking vector load/store + std::vector input_nchw_lengths_; + std::vector output_nchw_lengths_; + std::vector input_nchw_stride_; + std::vector output_nchw_stride_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + // for NHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. 
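// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch] The A[ReduceM,
// ReduceK] descriptor built above maps every NHWC output point (n, ho, wo, c)
// to a Y x X window of the padded input, walked with the window strides and
// dilations. The scalar reference below evaluates one such output point for
// max pooling, assuming a densely packed NHWC input; all names are local to
// this sketch.
// ---------------------------------------------------------------------------
#include <limits>
#include <vector>

namespace pool2d_reference_sketch {

inline float max_pool_nhwc_one_point(const std::vector<float>& in_nhwc,
                                     int Hi, int Wi, int C,
                                     int n, int ho, int wo, int c,
                                     int Y, int X,
                                     int strideH, int strideW,
                                     int dilationH, int dilationW,
                                     int padH, int padW)
{
    float acc = -std::numeric_limits<float>::infinity();
    for(int y = 0; y < Y; ++y)     // the Y * X window taps form the ReduceK axis
    {
        for(int x = 0; x < X; ++x)
        {
            const int hi = ho * strideH + y * dilationH - padH;
            const int wi = wo * strideW + x * dilationW - padW;
            if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi)
                continue; // padded region acts as the identity element for max
            const float v = in_nhwc[((n * Hi + hi) * Wi + wi) * C + c];
            acc = v > acc ? v : acc;
        }
    }
    return acc;
}

} // namespace pool2d_reference_sketch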
+ static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K + + using gridwise_reduce = + GridwiseReduction_mk_to_m_threadwise; - std::unique_ptr + const auto kernel = + kernel_reduce_threadwise; + + ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0); + + const index_t grid_size = (M / M_BlockTileSize); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_m_, + arg.in_element_op_, + arg.acc_element_op_, + float(1), + arg.p_in_dev_, + nullptr, + float(0), + arg.p_out_dev_, + arg.p_out_indices_dev_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + // C should be fastest dimension + if(pArg->input_nchw_stride_[1] != 1) + return false; + + for(int i = 0; i < InOutRank; ++i) + { + if(pArg->input_nchw_stride_[i] == 1 && + pArg->input_nchw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; + + if(pArg->output_nchw_stride_[i] == 1 && + pArg->output_nchw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; + } + + return true; + } + + virtual std::unique_ptr MakeArgumentPointer(const void* p_in_dev, void* p_out_dev, void* p_out_indices_dev, - std::vector input_lengths, - std::vector window_lengths, - std::vector output_lengths, - std::vector input_stride, - std::vector output_stride, - std::vector indices_stride, - std::vector window_strides, - std::vector window_dilations, - std::vector input_left_pads, - std::vector input_right_pads, + std::vector input_nchw_lengths, + std::vector window_yx_lengths, + std::vector output_nchw_lengths, + std::vector input_nchw_stride, + std::vector output_nchw_stride, + std::vector indices_nchw_stride, + std::vector window_yx_strides, + std::vector window_yx_dilations, + std::vector input_left_hw_pads, + std::vector input_right_hw_pads, std::vector pooling_dims) override { - static constexpr index_t InOutRank = 4; - static constexpr index_t WindowRank = 2; - - if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || - input_lengths.size() != InOutRank || window_strides.size() != WindowRank || - window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank || - input_right_pads.size() != WindowRank) + if(input_nchw_lengths.size() != InOutRank || window_yx_lengths.size() != WindowRank || + input_nchw_lengths.size() != InOutRank || window_yx_strides.size() != WindowRank || + window_yx_dilations.size() != WindowRank || input_left_hw_pads.size() != WindowRank || + input_right_hw_pads.size() != WindowRank) throw std::runtime_error("dimension is incorrect"); if(pooling_dims != std::vector{2, 3}) throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far"); - // NCHW to NCDHW - input_lengths.insert(input_lengths.begin() + 2, 1); - output_lengths.insert(output_lengths.begin() + 2, 1); - input_stride.insert(input_stride.begin() + 2, 0); - output_stride.insert(output_stride.begin() + 2, 0); - indices_stride.insert(indices_stride.begin() + 2, 0); - - // YX to ZYX - window_lengths.insert(window_lengths.begin(), 1); - window_strides.insert(window_strides.begin(), 0); - window_dilations.insert(window_dilations.begin(), 0); - input_left_pads.insert(input_left_pads.begin(), 0); - input_right_pads.insert(input_right_pads.begin(), 0); - - pooling_dims = {2, 3, 4}; - - return 
DevicePool3D::MakeArgumentPointer(p_in_dev, - p_out_dev, - p_out_indices_dev, - input_lengths, - window_lengths, - output_lengths, - input_stride, - output_stride, - indices_stride, - window_strides, - window_dilations, - input_left_pads, - input_right_pads, - pooling_dims); + if(output_nchw_stride != indices_nchw_stride) + throw std::runtime_error( + "output_nchw_stride need to be equal to indices_nchw_stride for now"); + + return std::make_unique(static_cast(p_in_dev), + static_cast(p_out_dev), + static_cast(p_out_indices_dev), + input_nchw_lengths, + output_nchw_lengths, + input_nchw_stride, + output_nchw_stride, + indices_nchw_stride, + window_yx_lengths, + window_yx_strides, + window_yx_dilations, + input_left_hw_pads, + input_right_hw_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DevicePool2dFwd_NHWC_NHWC<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp index 2481c5c76971db6bb2ddc390510089af5f656ae6..67956d9f3f0ee3d1ae43faae30c0eee2a74a2c20 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,7 +19,7 @@ namespace device { template std::pair get_2d_lengths(const std::vector& inLengths) { - static_assert(Rank <= 6, "bigger Rank size not supported!"); + static_assert(Rank <= 12, "bigger Rank size not supported!"); long_index_t invariant_total_length = 1; long_index_t reduce_total_length = 1; @@ -38,7 +38,7 @@ std::pair get_2d_lengths(const std::vector& template std::pair get_2d_lengths(const std::array& inLengths) { - static_assert(Rank <= 6, "bigger Rank size not supported!"); + static_assert(Rank <= 12, "bigger Rank size not supported!"); long_index_t invariant_total_length = 1; long_index_t reduce_total_length = 1; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp index bf3deeb57acbdff95995f5da096ab574140bcb42..b4873e3403ce08648613358f18af9889af534e6c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
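// ---------------------------------------------------------------------------
// [Editor's illustrative sketch -- not part of the patch] The get_2d_lengths
// helpers whose Rank limit is raised from 6 to 12 above collapse an N-D length
// vector into an (invariant, reduce) pair: the product of the leading
// Rank - NumReduceDim lengths and the product of the trailing NumReduceDim
// lengths. The template below restates that computation; names are local to
// this sketch.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <utility>
#include <vector>

namespace reduce_lengths_sketch {

template <int Rank, int NumReduceDim>
std::pair<std::int64_t, std::int64_t> get_2d_lengths(const std::vector<std::int64_t>& lengths)
{
    std::int64_t invariant_total_length = 1;
    std::int64_t reduce_total_length    = 1;

    for(int i = 0; i < Rank - NumReduceDim; ++i)
        invariant_total_length *= lengths[i];

    for(int i = Rank - NumReduceDim; i < Rank; ++i)
        reduce_total_length *= lengths[i];

    return {invariant_total_length, reduce_total_length};
}

} // namespace reduce_lengths_sketch
// Example: Rank = 4, NumReduceDim = 3, lengths = {N, C, H, W} yields
// (N, C * H * W), i.e. M = N and K = C * H * W as in the normalization
// comments earlier in this diff.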
#pragma once @@ -51,7 +51,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce { - static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(Rank <= 12, "Bigger Rank size is not supported!"); static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, "Invalid thread cluster size assignments!"); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp index 6c5895b010659cafb98e4a78bc6587d0222e02bd..8291575fb82cf133e37a47e59cd759eaec95c12a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,7 +11,6 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" namespace ck { @@ -48,7 +47,7 @@ struct DeviceReduceThreadWise : public DeviceReduce { - static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(Rank <= 12, "Bigger Rank size is not supported!"); static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..764b9312f358bbd07a3bc7816668b8bf675a2bf4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" + +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD + +{ + static_assert(Rank <= 12, "Bigger Rank size is not supported!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr index_t NumSrcDim = Rank; + static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 
1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = + math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto + MakeDsDescriptor(const std::array, NumDTensor> DsLengths, + 
std::array, NumDTensor> DsStrides) + { + return generate_tuple( + [&](auto i) { + return DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(DsLengths[i], + DsStrides[i]); + }, + Number{}); + } + + using InGridDesc_M_K = decltype(MakeSrc2dDescriptor({}, {})); + using OutGridDesc_M = decltype(MakeDst1dDescriptor({}, {})); + using DsGridDesc_M = decltype(MakeDsDescriptor({}, {})); + + using GridwiseReduce = + GridwiseReduction_mk_to_m_threadwise_multi_d; + + using DsGridPointer = typename GridwiseReduce::DsGridPointer; + + struct Argument : public BaseArgument + { + Argument(const std::array inLengths, + const std::array inStrides, + const std::array, NumDTensor> DsLengths, + const std::array, NumDTensor> DsStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + const InDataType* in_dev, + const std::array ds_dev, + OutDataType* out_dev, + const InElementwiseOperation in_elementwise_op, + const OutElementwiseOperation out_elementwise_op) + : DsLengths_{DsLengths}, + DsStrides_{DsStrides}, + outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + out_dev_{out_dev}, + in_elementwise_op_{in_elementwise_op}, + out_elementwise_op_{out_elementwise_op} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid_(i) = static_cast(ds_dev[i]); + }); + + ds_grid_desc_m_ = MakeDsDescriptor(DsLengths, DsStrides); + } + + std::array inLengths_; + std::array inStrides_; + + std::array, NumDTensor> DsLengths_; + std::array, NumDTensor> DsStrides_; + + std::array outLengths_; + std::array outStrides_; + + const InDataType* in_dev_; + OutDataType* out_dev_; + + DsGridPointer p_ds_grid_; + + InElementwiseOperation in_elementwise_op_; + OutElementwiseOperation out_elementwise_op_; + + DsGridDesc_M ds_grid_desc_m_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int numBlockTileIteration; + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = + DeviceReduceThreadWiseMultiD::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); + const auto out_grid_desc_m = + DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + + float avg_time = 0; + + const auto kernel = kernel_reduce_threadwise_multi_d; + + avg_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + arg.ds_grid_desc_m_, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.out_elementwise_op_, + arg.in_dev_, + arg.p_ds_grid_, + arg.out_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + 
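// ---------------------------------------------------------------------------
// Illustrative sketch (editor addition, not CK library code): the launch
// bookkeeping in DeviceReduceThreadWiseMultiD::Argument reduces to two integer
// ceilings, gridSize = ceil(invariant_total_length / (BlockSize *
// MThreadSliceSize)) and numBlockTileIteration = ceil(reduce_total_length /
// KThreadSliceSize). The names below are illustrative stand-ins:
#include <cstdint>
#include <iostream>

struct LaunchParams
{
    std::int64_t grid_size;
    int num_block_tile_iteration;
};

LaunchParams compute_launch_params(std::int64_t invariant_total_length,
                                   std::int64_t reduce_total_length,
                                   int block_size,          // BlockSize
                                   int m_thread_slice_size, // MThreadSliceSize
                                   int k_thread_slice_size) // KThreadSliceSize
{
    const std::int64_t m_block_tile = std::int64_t{block_size} * m_thread_slice_size;
    const std::int64_t k_block_tile = k_thread_slice_size; // K_BlockTileSize = 1 * KThreadSliceSize

    LaunchParams p{};
    p.grid_size                = (invariant_total_length + m_block_tile - 1) / m_block_tile;
    p.num_block_tile_iteration =
        static_cast<int>((reduce_total_length + k_block_tile - 1) / k_block_tile);
    return p;
}

int main()
{
    // 1000 invariant rows, 37 reduced elements per row, BlockSize 256,
    // MThreadSliceSize 4, KThreadSliceSize 8 -> 1 block, 5 K iterations.
    const auto p = compute_launch_params(1000, 37, 256, 4, 8);
    std::cout << p.grid_size << " block(s), " << p.num_block_tile_iteration
              << " K iteration(s) per thread\n";
}
// ---------------------------------------------------------------------------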
bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + std::cerr << "reduce_total_length = " << pArg->reduce_total_length + << " KThreadSliceSize = " << KThreadSliceSize << std::endl; + + // cases with big reduce_total_length should be handled by Blockwise kernel + if(pArg->reduce_total_length / KThreadSliceSize >= 32) + return (false); + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array, NumDTensor> DsLengths, + const std::array, NumDTensor> DsStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + const void* in_dev, + const std::array ds_dev, + void* out_dev, + const InElementwiseOperation in_elementwise_op, + const OutElementwiseOperation out_elementwise_op) override + { + return std::make_unique(inLengths, + inStrides, + DsLengths, + DsStrides, + outLengths, + outStrides, + reduceDims, + static_cast(in_dev), + ds_dev, + static_cast(out_dev), + in_elementwise_op, + out_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceThreadWiseMultiD<" << BlockSize << ","; + str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << 1 << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index cfab7d29c9002f4fc61b154e127c4f09ab9005c9..609697cb97a465c8248823b83930f1daa416a7e3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -57,7 +57,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index 2855535ede177d8f4ea35ea432b3567c315d5ff9..21a31c61fca586cbe385070cfac652f1ee15e7c3 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -55,7 +55,10 @@ 
struct MaskOutUpperTrianglePredicate template struct C0MatrixMask_impl { - constexpr C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {} + constexpr __host__ __device__ C0MatrixMask_impl(index_t NRaw) + : NRaw_(NRaw), predicate_(MaskOutPredicate{}) + { + } __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const { diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index c66d2fc516babe3aa5a16de053d9ada89eabf8ef..02941531475ffe71cf76f96a7259facfaccbec00 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder +auto grid_desc(MatrixPadder matrix_padder, + CDesc_MRaw_NRaw conv_desc) +{ + auto res = matrix_padder.PadCDescriptor_M_N(conv_desc); + return res; +} // M/N/KPerTileType could be index_t or Number<> template (x1); }; + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 + x1); + }; + template <> __host__ __device__ constexpr void operator()(half_t& y, const float& x0, const half_t& x1) const @@ -68,6 +75,15 @@ struct Add y = ck::type_convert(y_tmp); } + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 + x2_tmp; + y = ck::type_convert(y_tmp); + } + template <> __host__ __device__ constexpr void operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const @@ -76,13 +92,130 @@ struct Add }; }; -struct ScaleAdd +struct Max +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::max(x0_converted, x1_converted); + } +}; + +struct Min { - __host__ __device__ ScaleAdd(float scale) : scale_(scale) {} + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::min(x0_converted, x1_converted); + } +}; +struct Multiply +{ template __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 * type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 * x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 * x1_tmp; + } + + template <> + __host__ __device__ constexpr void + 
operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 * x1; + }; +}; + +struct ScaleAdd +{ + __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {} + + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const + { + y = ck::type_convert(scale_ * ck::type_convert(x0) + ck::type_convert(x1)); + } + template <> __host__ __device__ void operator()(float& y, const float& x0, const half_t& x1) const @@ -146,7 +279,7 @@ struct Subtract struct Bilinear { - Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; @@ -165,6 +298,14 @@ struct Bilinear y = alpha_ * x0 + beta_ * x1; }; + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = type_convert(alpha_ * type_convert(x0) + + beta_ * type_convert(x1)); + }; + template <> __host__ __device__ constexpr void operator()(half_t& y, const half_t& x0, const half_t& x1) const @@ -179,6 +320,33 @@ struct Bilinear y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); }; + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x0_tmp = type_convert(x0); + const float x1_tmp = type_convert(x1); + const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp; + y = type_convert(y_tmp); + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + const float y_tmp = alpha_ * x0 + beta_ * x1_tmp; + y = y_tmp; + }; + + template <> + __host__ __device__ constexpr void operator()( + std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const + { + y = type_convert(alpha_ * type_convert(x0) + + beta_ * type_convert(x1)); + }; + float alpha_; float beta_; }; @@ -228,6 +396,14 @@ struct AddRelu y = a > 0.0f ? a : 0.0f; }; + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float a = x0 + type_convert(x1); + y = a > type_convert(0.0f) ? 
a : type_convert(0.0f); + }; + template <> __host__ __device__ constexpr void operator()(int& y, const int& x0, const int8_t& x1) const @@ -318,6 +494,174 @@ struct AddFastGelu e = type_convert(x1_f); } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const + { + const float x0_f = type_convert(c) + type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c + type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } +}; + +// E = MultiplyFastGelu(C + D) +struct MultiplyFastGelu +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const + { + const float x = c * d; + + FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const + { + const half_t x = c * d; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const float& c, const half_t& d) const + { + const float x0_f = c * d; + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const + { + const float x0_f = type_convert(c) * type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c * type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } +}; + +// E = Silu(C + D) +struct AddSilu +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const + { + const float x = c + d; + + Silu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const + { + const half_t x = c + d; + + Silu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const float& c, const half_t& d) const + { + const float x0_f = c + d; + + float x1_f = 0; + + Silu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c + type_convert(d); + + float x1_f = 0; + + Silu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } +}; + +struct ConvScaleAdd +{ + __host__ __device__ ConvScaleAdd(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ void + operator()(f8_t& e, const 
float& c, const float& d) const + { + float x; + Add{}.template operator()(x, c * scale_in_ * scale_wei_, d); + e = type_convert(x * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; }; } // namespace element_wise diff --git a/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3cc1c3c42ccc260f463f57968092d77ee77c73ca --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// y = UnaryOp0(UnaryOp1(...(x))) +template +struct UnaryCombinedOp +{ + __host__ __device__ UnaryCombinedOp() : unary_ops_() {} + + __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // Execute first unary op to copy data to y + unary_ops_.At(Number<0>{})(y, x); + + static_for<1, Tuple::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); }); + }; + + Tuple unary_ops_; +}; + +// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1)) +template +struct BinaryWithUnaryCombinedOp +{ + __host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {} + + __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1) + : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result); + }; + + private: + BinaryOp binary_op_; + UnaryOp0 unary_op0_; + UnaryOp1 unary_op1_; +}; + +// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2)) +template +struct TrinaryWithUnaryCombinedOp +{ + __host__ __device__ TrinaryWithUnaryCombinedOp() + : binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_() + { + } + + __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, + BinaryOp0 binary_op1, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1, + UnaryOp2 unary_op2) + : binary_op0_(binary_op0), + binary_op1_(binary_op1), + unary_op0_(unary_op0), + unary_op1_(unary_op1), + unary_op2_(unary_op2) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const + { + + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + Y unary_x2_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + unary_op2_(unary_x2_tmp_result, x2); + binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result); + binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result); + }; + + private: + BinaryOp0 binary_op0_{}; + BinaryOp1 binary_op1_{}; + UnaryOp0 unary_op0_{}; + UnaryOp1 unary_op1_{}; + UnaryOp2 unary_op2_{}; +}; + +using ScaleScalePass = UnaryCombinedOp; +using ScaleScaleRelu = UnaryCombinedOp; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 3fdb391a0215e83328905b77148b675defc27af7..b914c0b96f7041d44c8ff0761e868c6e7b581658 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -195,6 +195,125 @@ struct AddMultiply } }; +// C = A * B +// E = C x D0 + D1 +struct MultiplyAdd +{ + template + __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ void operator()(half_t& e, + const half_t& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = (c * d0) + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = type_convert(c) * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(bhalf_t& e, + const float& c, + const bhalf_t& d0, + const bhalf_t& d1) const + { + const bhalf_t y = type_convert(c) * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(float& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const float y = c * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const float& d0, + const float& d1) const + { + const float y = c * d0 + d1; + e = y; + } +}; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +struct MultiplyAddFastGelu +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const + { + const float x0_f = c * ck::type_convert(d0) + ck::type_convert(d1); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = ck::type_convert(x1_f); + } +}; + // E = FastGelu(C + D0 + D1) struct AddAddFastGelu { @@ -266,6 +385,71 @@ struct AddAddFastGelu } }; +// E = Relu(alpha1 * C + alpha2 * D0 + D1) +struct ScaleAddScaleAddRelu +{ + + ScaleAddScaleAddRelu(const float alpha1 = 1.f, const 
float alpha2 = 1.f) + : alpha1_(alpha1), alpha2_(alpha2) + { + } + + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(float& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x = c * alpha1_ + alpha2_ * d0 + d1; + e = x > 0 ? x : 0; + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const + { + const float x = type_convert(c) * alpha1_ + alpha2_ * type_convert(d0) + + type_convert(d1); + + float result = 0; + result = x > 0 ? x : 0; + + e = type_convert(result); + } + + template <> + __host__ __device__ constexpr void operator()( + bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const + { + const float x = type_convert(c) * alpha1_ + alpha2_ * type_convert(d0) + + type_convert(d1); + + float result = 0; + result = x > 0 ? x : 0; + + e = type_convert(result); + } + + template <> + __host__ __device__ constexpr void operator()( + int8_t& e, const int8_t& c, const float& d0, const float& d1) const + { + const float x = type_convert(c) * alpha1_ + alpha2_ * d0 + d1; + + float result = 0; + result = x > 0 ? x : 0; + + e = type_convert(result); + } + + const float alpha1_; + const float alpha2_; +}; + struct Normalize { // FIXME: is double absolutely necessary? diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index c47822499c870a3a1cd33a3bc11adf4fcf5ef65f..8870298adb57deb50f32a7ab576eb037a2b62b62 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
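// ---------------------------------------------------------------------------
// Illustrative sketch (editor addition, not CK library code): the
// ScaleAddScaleAddRelu functor added above computes
// E = Relu(alpha1 * C + alpha2 * D0 + D1) per element. A hypothetical
// host-side stand-in with the same contract, shown only to make the intended
// arithmetic explicit:
#include <cstdio>

struct ScaleAddScaleAddReluRef
{
    float alpha1 = 1.f;
    float alpha2 = 1.f;

    void operator()(float& e, float c, float d0, float d1) const
    {
        const float x = alpha1 * c + alpha2 * d0 + d1;
        e = x > 0.f ? x : 0.f; // Relu
    }
};

int main()
{
    const ScaleAddScaleAddReluRef op{0.5f, 2.f};

    float e = 0.f;
    op(e, 4.f, -1.f, -3.f); // 0.5*4 + 2*(-1) + (-3) = -3 -> clamped to 0
    std::printf("%g\n", e);

    op(e, 4.f, 1.f, 1.f);   // 0.5*4 + 2*1 + 1 = 5
    std::printf("%g\n", e);
}
// ---------------------------------------------------------------------------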
#pragma once @@ -7,64 +7,131 @@ #include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" #include "ck/utility/type_convert.hpp" +#include namespace ck { namespace tensor_operation { namespace element_wise { -#if CK_WORKAROUND_SWDEV_383542 -extern "C" __device__ float __ocml_native_recip_f32(float); -#endif +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnon-virtual-dtor" +struct UnaryOpBase +{ + public: + __host__ __device__ ~UnaryOpBase() = default; + + __host__ __device__ constexpr UnaryOpBase() = default; + __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default; + __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default; + __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default; + __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default; + + __host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0; + + __host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0; + + __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0; + + __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0; + + __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0; + + __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0; +}; + +struct PassThroughPack2 +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const + { + auto t = type_convert(x); + y = type_convert(t); + } + constexpr const static bool is_pack2_invocable = true; +}; -struct PassThrough +struct PassThrough final : public UnaryOpBase { + __host__ __device__ constexpr PassThrough() = default; + __host__ __device__ constexpr PassThrough(const PassThrough&) = default; + __host__ __device__ constexpr PassThrough(PassThrough&&) = default; + __host__ __device__ PassThrough& operator=(const PassThrough&) = default; + __host__ __device__ PassThrough& operator=(PassThrough&&) = default; + __host__ __device__ ~PassThrough() = default; + + __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; } + + __host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; } + template __host__ __device__ void operator()(Y& y, const X& x) const; template <> - __host__ __device__ void operator()(double& y, const double& x) const + __host__ __device__ void operator()(float& y, const double& x) const { - y = x; + y = type_convert(x); } template <> - __host__ __device__ void operator()(float& y, const float& x) const + __host__ __device__ void operator()(double& y, const float& x) const { - y = x; + y = type_convert(x); } template <> - __host__ __device__ void operator()(half_t& y, const half_t& x) const + __host__ __device__ void operator()(half_t& y, const float& x) const { - y = x; + y = type_convert(x); } template <> - __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + __host__ __device__ void operator()(bhalf_t& 
y, const float& x) const { - y = x; + y = type_convert(x); } template <> - __host__ __device__ void operator()(int32_t& y, const int32_t& x) const + __host__ __device__ void operator()(float& y, const bhalf_t& x) const { - y = x; + y = type_convert(x); } template <> - __host__ __device__ void operator()(bhalf_t& y, const float& x) const + __host__ __device__ void operator()(bhalf_t& y, const half_t& x) const { y = type_convert(x); } template <> - __host__ __device__ void operator()(bhalf_t& y, const half_t& x) const + __host__ __device__ void operator()(float& y, const half_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const int8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const int8_t& x) const { y = type_convert(x); } template <> - __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + __host__ __device__ void operator()(uint8_t& y, const uint8_t& x) const { y = x; } @@ -75,12 +142,35 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(int32_t& y, const int8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(int8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(float& y, const int8_t& x) const + { + y = type_convert(x); + } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> __host__ __device__ void operator()(int4_t& y, const int4_t& x) const { y = x; } + template <> + __host__ __device__ void operator()(int4_t& y, const int& x) const + { + y = type_convert(x); + } #endif template <> @@ -112,6 +202,36 @@ struct PassThrough { y = type_convert(x); } + + template <> + __host__ __device__ void operator()(bf8_t& y, const bf8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const half_t& x) const + { + y = ck::type_convert(x); + } }; struct UnaryConvert @@ -147,7 +267,8 @@ struct ConvertF8SR __host__ __device__ void operator()(Y& y, const X& x) const { // check Y datatype - static_assert(is_same::value, "Data type is not supported by this operation!"); + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); // check X datatype static_assert(is_same::value || is_same::value, @@ -157,12 +278,47 @@ struct ConvertF8SR } }; +struct ConvertF8RNE +{ + // convert to fp8 using rounding to nearest even + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = f8_convert_rne(x); + } +}; + struct Scale { - __host__ __device__ Scale(float scale) : scale_(scale) {} + __host__ __device__ Scale(float scale = 1.f) : scale_(scale) {} template - __host__ __device__ void operator()(Y& y, const X& x) const; + __host__ __device__ void operator()(Y& y, const X& x) const + { + y = 
ck::type_convert(ck::type_convert(x) * scale_); + } + + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + y = ck::type_convert(scale_) * x; + }; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + const float x_tmp = ck::type_convert(x); + const float y_tmp = scale_ * x_tmp; + y = ck::type_convert(y_tmp); + }; template <> __host__ __device__ void operator()(float& y, const float& x) const @@ -176,6 +332,12 @@ struct Scale y = scale_ * x; }; + template <> + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + { + y = ck::type_convert(scale_ * ck::type_convert(x)); + }; + float scale_; }; @@ -203,12 +365,39 @@ struct UnaryDivide __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "Data type is not supported by this operation!"); y = x / type_convert(divider_); }; + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + float x_ = type_convert(x); + float divider_f_ = type_convert(divider_); + + y = type_convert(x_ / divider_f_); + }; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float x_ = type_convert(x); + float divider_f_ = type_convert(divider_); + + y = type_convert(x_ / divider_f_); + }; + + template <> + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + float x_ = type_convert(x); + float divider_f_ = type_convert(divider_); + + y = type_convert(x_ / divider_f_); + }; + int32_t divider_ = 1; }; @@ -217,8 +406,8 @@ struct UnarySquare template __host__ __device__ void operator()(T& y, const T& x) const { - static_assert(is_same_v || is_same_v || is_same_v || - is_same_v + static_assert(is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 || is_same_v #endif @@ -228,17 +417,48 @@ struct UnarySquare }; }; -struct UnaryAbs +struct UnaryAbs final : public UnaryOpBase { - template - __host__ __device__ void operator()(T& y, const T& x) const + __host__ __device__ constexpr UnaryAbs() = default; + __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default; + __host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default; + __host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default; + __host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default; + __host__ __device__ ~UnaryAbs() = default; + + __host__ __device__ inline void operator()(float& y, const float& x) const final { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Data type is not supported by this operation!"); + y = ck::math::abs(x); + } + + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + y = ck::math::abs(x); + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + y = ck::math::abs(x); + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + y = ck::math::abs(x); + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + y = ck::math::abs(x); + } + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + { y = ck::math::abs(x); + } + + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + y = ck::type_convert(ck::math::abs(ck::type_convert(x))); }; }; @@ -254,20 +474,41 @@ 
struct UnarySqrt }; }; -struct Relu +struct Relu final : public UnaryOpBase { - template - __host__ __device__ void operator()(T& y, const T& x) const + __host__ __device__ constexpr Relu() = default; + __host__ __device__ constexpr Relu(const Relu&) = default; + __host__ __device__ constexpr Relu(Relu&&) = default; + __host__ __device__ Relu& operator=(const Relu&) = default; + __host__ __device__ Relu& operator=(Relu&&) = default; + __host__ __device__ ~Relu() = default; + + __host__ __device__ inline void operator()(float& y, const float& x) const final { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Data type is not supported by this operation!"); y = x > 0 ? x : 0; } - template <> - __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + y = x > 0 ? x : 0; + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + y = x > 0 ? x : 0; + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + y = x > 0 ? x : 0; + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + y = x > 0 ? x : 0; + } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { float x_f32 = ck::type_convert(x); float y_f32 = x_f32 > 0 ? x_f32 : 0; @@ -279,7 +520,7 @@ struct Relu // https://paperswithcode.com/method/gelu // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) // host code use higher accuracy "exp" and "div" -// gpu code use lower accuracy "__expf" and "rcp" function +// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function struct FastGelu { template @@ -292,27 +533,25 @@ struct FastGelu template <> __host__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - - y = x * cdf; + // const float u = -2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = exp(u); + y = x / (1.f + emu); } #endif - // device code, use lower precision "__expf" and "rcp" + // device code, use lower precision "__ocml_exp_f32" and "rcp" template <> __device__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = __expf(-u); - -#if !CK_WORKAROUND_SWDEV_383542 - const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f); -#else - const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); -#endif + // const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = __ocml_exp_f32(u); - y = x * cdf; + y = x * ck::math::rcp(1.f + emu); } template <> @@ -354,6 +593,46 @@ struct FastGelu y = type_convert(y_f); } + + template <> + __host__ void operator()(bhalf_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(bhalf_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float y_f; + 
+ this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + __host__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } }; // https://paperswithcode.com/method/gelu @@ -377,48 +656,1148 @@ struct Gelu } }; -struct Sigmoid +struct Sigmoid final : public UnaryOpBase { - template - __host__ __device__ void operator()(T& y, const T& x) const + __host__ __device__ constexpr Sigmoid() = default; + __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default; + __host__ __device__ constexpr Sigmoid(Sigmoid&&) = default; + __host__ __device__ Sigmoid& operator=(const Sigmoid&) = default; + __host__ __device__ Sigmoid& operator=(Sigmoid&&) = default; + __host__ __device__ ~Sigmoid() = default; + + __host__ __device__ inline void operator()(float& y, const float& x) const final { - static_assert(is_same::value || is_same::value || - is_same::value, - "Data type is not supported by this operation!"); + constexpr float one = type_convert(1); + y = one / (one + ck::math::exp(-x)); + } - y = 1 / (ck::type_convert(1) + exp(-x)); - }; + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + constexpr double one = type_convert(1); + y = one / (one + ck::math::exp(-x)); + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + constexpr int32_t one = type_convert(1); + y = one / (one + ck::math::exp(-x)); + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + constexpr int8_t one = type_convert(1); + y = one / (one + ck::math::exp(-x)); + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + constexpr half_t one = type_convert(1); + y = one / (one + ck::math::exp(-x)); + } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + { + constexpr float one = type_convert(1); + float x_f32 = ck::type_convert(x); + float y_f32 = one / (one + ck::math::exp(x_f32)); + y = ck::type_convert(y_f32); + } }; -struct TanH +struct Silu { template __host__ __device__ void operator()(T& y, const T& x) const { - static_assert(is_same::value || is_same::value || - is_same::value, + static_assert(is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v, "Data type is not supported by this operation!"); - - y = ck::math::tanh(x); + constexpr T one = type_convert(1); + y = x * (one / (one + ck::math::exp(-x))); }; }; -struct Swish +struct TanH final : public UnaryOpBase { - Swish(float beta = 1.0f) : beta_(beta) {} - - template - __host__ __device__ void operator()(T& y, const T& x) const + __host__ __device__ constexpr TanH() = default; + __host__ __device__ constexpr TanH(const TanH&) = default; + __host__ __device__ constexpr TanH(TanH&&) = default; + __host__ __device__ TanH& operator=(const TanH&) = default; + __host__ __device__ TanH& operator=(TanH&&) = default; + __host__ __device__ ~TanH() = default; + + __host__ __device__ inline void operator()(float& y, const float& x) const final { - static_assert(is_same::value || is_same::value || - is_same::value, - "Data type is not supported by this operation!"); + y = ck::math::tanh(x); + } - y = x / (ck::type_convert(1) + ck::math::exp(-beta_ * x)); + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + y = ck::math::tanh(x); + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { 
+ y = ck::math::tanh(x); + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + y = ck::math::tanh(x); + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + y = ck::math::tanh(x); + } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + { + y = ck::math::tanh(x); + } +}; + +struct ACos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::acos(x); + }; +}; + +struct Neg +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::neg(x); + }; +}; + +struct ATan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atan(x); + }; +}; + +struct Sin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sin(x); + }; +}; + +struct ASinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asinh(x); + }; +}; + +struct Cos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cos(x); + }; +}; + +struct ACosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::acosh(x); + }; +}; + +struct Tan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::tan(x); + }; +}; + +struct ATanH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atanh(x); + }; +}; + +struct SinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sinh(x); + }; +}; + +struct Ceil +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this 
operation!"); + + y = ck::math::ceil(x); + }; +}; + +struct Exp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::exp(x); + }; +}; + +struct CosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cosh(x); + }; +}; + +struct Floor +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::floor(x); + }; +}; + +struct Log +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::log(x); + }; +}; + +struct ASin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asin(x); + }; +}; + +struct Rcp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::rcp(x); + }; +}; + +struct Swish final : public UnaryOpBase +{ + __host__ __device__ constexpr Swish(const Swish&) = default; + __host__ __device__ constexpr Swish(Swish&&) = default; + __host__ __device__ ~Swish() = default; + + __host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {} + + __host__ __device__ float get_beta() const { return beta_; } + + const float beta_; + + __host__ __device__ inline void operator()(float& y, const float& x) const final + { + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); + } + + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); + } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + { + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); + } + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type 
is not supported by this operation!"); + + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); + } +}; + +struct SoftRelu final : public UnaryOpBase +{ + __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default; + __host__ __device__ constexpr SoftRelu(SoftRelu&&) = default; + __host__ __device__ ~SoftRelu() = default; + + __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {} + + __host__ __device__ float get_alpha() const { return alpha_; } + + const float alpha_; + + __host__ __device__ inline void operator()(float& y, const float& x) const final + { + float casted_alpha = type_convert(alpha_); + constexpr float one = type_convert(1); + y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + } + + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + double casted_alpha = type_convert(alpha_); + constexpr double one = type_convert(1); + y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + int32_t casted_alpha = type_convert(alpha_); + constexpr int32_t one = type_convert(1); + y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + int8_t casted_alpha = type_convert(alpha_); + constexpr int8_t one = type_convert(1); + y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + half_t casted_alpha = type_convert(alpha_); + constexpr half_t one = type_convert(1); + y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + { + bhalf_t casted_alpha = type_convert(alpha_); + constexpr bhalf_t one = type_convert(1); + y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + } +}; + +struct Power final : public UnaryOpBase +{ + __host__ __device__ constexpr Power(const Power&) = default; + __host__ __device__ constexpr Power(Power&&) = default; + __host__ __device__ ~Power() = default; + + __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) + : alpha_(alpha), beta_(beta), gamma_(gamma) + { + } + + __host__ __device__ float get_alpha() const { return alpha_; } + + __host__ __device__ float get_beta() const { return beta_; } + + __host__ __device__ float get_gamma() const { return gamma_; } + + const float alpha_; + const float beta_; + const float gamma_; + + __host__ __device__ inline void operator()(float& y, const float& x) const final + { + float casted_alpha = type_convert(alpha_); + float casted_beta = type_convert(beta_); + float casted_gamma = type_convert(gamma_); + + float shifted_scaled_x = casted_alpha + casted_beta * x; + y = ck::math::pow(shifted_scaled_x, casted_gamma); + } + + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + double casted_alpha = type_convert(alpha_); + double casted_beta = type_convert(beta_); + double casted_gamma = type_convert(gamma_); + + double shifted_scaled_x = casted_alpha + casted_beta * x; + y = ck::math::pow(shifted_scaled_x, casted_gamma); + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + int32_t casted_alpha = type_convert(alpha_); + int32_t casted_beta = 
type_convert(beta_); + int32_t casted_gamma = type_convert(gamma_); + + int32_t shifted_scaled_x = casted_alpha + casted_beta * x; + y = ck::math::pow(shifted_scaled_x, casted_gamma); + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + int8_t casted_alpha = type_convert(alpha_); + int8_t casted_beta = type_convert(beta_); + int8_t casted_gamma = type_convert(gamma_); + + int8_t shifted_scaled_x = casted_alpha + casted_beta * x; + y = ck::math::pow(shifted_scaled_x, casted_gamma); + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + half_t casted_alpha = type_convert(alpha_); + half_t casted_beta = type_convert(beta_); + half_t casted_gamma = type_convert(gamma_); + + half_t shifted_scaled_x = casted_alpha + casted_beta * x; + y = ck::math::pow(shifted_scaled_x, casted_gamma); + } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + { + bhalf_t casted_alpha = type_convert(alpha_); + bhalf_t casted_beta = type_convert(beta_); + bhalf_t casted_gamma = type_convert(gamma_); + + bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x; + y = ck::math::pow(shifted_scaled_x, casted_gamma); + } +}; + +struct ClippedRelu final : public UnaryOpBase +{ + __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default; + __host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default; + __host__ __device__ ~ClippedRelu() = default; + + __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f) + : alpha_(alpha), beta_(beta) + { + } + + __host__ __device__ float get_alpha() const { return alpha_; } + + __host__ __device__ float get_beta() const { return beta_; } + + const float alpha_; + const float beta_; + + __host__ __device__ inline void operator()(float& y, const float& x) const final + { + float casted_alpha = type_convert(alpha_); + float casted_beta = type_convert(beta_); + y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); + } + + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + double casted_alpha = type_convert(alpha_); + double casted_beta = type_convert(beta_); + y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + int32_t casted_alpha = type_convert(alpha_); + int32_t casted_beta = type_convert(beta_); + y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + int8_t casted_alpha = type_convert(alpha_); + int8_t casted_beta = type_convert(beta_); + y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + half_t casted_alpha = type_convert(alpha_); + half_t casted_beta = type_convert(beta_); + y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); + } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + { + bhalf_t casted_alpha = type_convert(alpha_); + bhalf_t casted_beta = type_convert(beta_); + y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); + } +}; + +struct LeakyRelu final : public UnaryOpBase +{ + __host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default; + __host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default; + __host__ __device__ ~LeakyRelu() = default; + + __host__ __device__ 
LeakyRelu(float alpha = 0.f) : alpha_(alpha) {} + + __host__ __device__ float get_alpha() const { return alpha_; } + + const float alpha_; + + __host__ __device__ inline void operator()(float& y, const float& x) const final + { + float casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } + + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + double casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + int32_t casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + int8_t casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + half_t casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } + + __host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y, + [[maybe_unused]] const bhalf_t& x) const final + { + } +}; + +struct Elu final : public UnaryOpBase +{ + __host__ __device__ constexpr Elu(const Elu&) = default; + __host__ __device__ constexpr Elu(Elu&&) = default; + __host__ __device__ ~Elu() = default; + + __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {} + + __host__ __device__ float get_alpha() const { return alpha_; } + + const float alpha_; + + __host__ __device__ inline void operator()(float& y, const float& x) const final + { + float casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + } + + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + double casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + int32_t casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + int8_t casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + half_t casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + { + bhalf_t casted_alpha = type_convert(alpha_); + y = x > 0 ? 
x : casted_alpha * ck::math::expm1(x); + } +}; + +struct Logistic final : public UnaryOpBase +{ + __host__ __device__ constexpr Logistic(const Logistic&) = default; + __host__ __device__ constexpr Logistic(Logistic&&) = default; + __host__ __device__ ~Logistic() = default; + + __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {} + + __host__ __device__ float get_alpha() const { return alpha_; } + + const float alpha_; + + __host__ __device__ inline void operator()(float& y, const float& x) const final + { + float casted_alpha = type_convert(alpha_); + constexpr float one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } + + __host__ __device__ inline void operator()(double& y, const double& x) const final + { + double casted_alpha = type_convert(alpha_); + constexpr double one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } + + __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final + { + int32_t casted_alpha = type_convert(alpha_); + constexpr int32_t one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } + + __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + { + int8_t casted_alpha = type_convert(alpha_); + constexpr int8_t one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } + + __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + { + half_t casted_alpha = type_convert(alpha_); + constexpr half_t one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } + + __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + { + bhalf_t casted_alpha = type_convert(alpha_); + constexpr bhalf_t one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } +}; + +struct ConvInvscale +{ + __host__ __device__ ConvInvscale(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + e = type_convert(c / scale_in_ / scale_wei_ / scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +struct ConvScale +{ + __host__ __device__ ConvScale(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + e = type_convert(c * scale_in_ * scale_wei_ * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +struct ConvScaleRelu +{ + __host__ __device__ ConvScaleRelu(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + float x; + Relu{}(x, c * scale_in_ * scale_wei_); + e = type_convert(x * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +// support fastconvert of int8 to fp16 + +template +struct FastNumericArrayConverter 
+{ +}; + +template <> +struct FastNumericArrayConverter +{ + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + OutputArray Output; + + uint32_t* half_2 = reinterpret_cast(&Output); + uint32_t const uint8_4 = reinterpret_cast(Input); + + static constexpr uint32_t byte_selector_01 = 0x05010500; + static constexpr uint32_t byte_selector_23 = 0x05030502; + static constexpr uint32_t fp16_adder = 0x64646464; + half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01); + half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[0]) + : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[1]) + : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM)); + + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + +template +struct FastNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + FastNumericArrayConverter converter; + + OutputArray Output; + + using Vec_InputArray = vector_type; + using Vec_OutputArray = vector_type; + + Vec_OutputArray* half_4_ptr = reinterpret_cast(&Output); + Vec_InputArray const* uint8_4_ptr = reinterpret_cast(&Input); + + static_for<0, N / VEC_WIDTH, 1>{}( + [&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); }); + + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + +struct DynamicUnaryOp +{ + + DynamicUnaryOp& operator=(const DynamicUnaryOp& other) + { + if(this != &other) + { + unary_op_ptr_ = other.unary_op_ptr_; + unary_op_type_ = other.unary_op_type_; + } + return *this; + } + + __host__ __device__ DynamicUnaryOp() = delete; + + __host__ __device__ DynamicUnaryOp(const Swish& swish) + { + unary_op_type_ = UnaryOpType::Swish; + beta = swish.get_beta(); + } + + __host__ __device__ DynamicUnaryOp(const Swish&& swish) + { + unary_op_type_ = UnaryOpType::Swish; + beta = swish.get_beta(); + } + + __host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; } + + __host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; } + + __host__ __device__ DynamicUnaryOp(const PassThrough&) + { + unary_op_type_ = UnaryOpType::PassThrough; + } + + __host__ __device__ DynamicUnaryOp(const PassThrough&&) + { + unary_op_type_ = UnaryOpType::PassThrough; + } + + __host__ __device__ DynamicUnaryOp(const Logistic& logistic) + { + unary_op_type_ = UnaryOpType::Logistic; + alpha = logistic.get_alpha(); + } + + __host__ __device__ DynamicUnaryOp(const Logistic&& logistic) + { + unary_op_type_ = UnaryOpType::Logistic; + alpha = logistic.get_alpha(); + } + + __host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; } + + __host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; } + + __host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; } + + __host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; } + + __host__ __device__ DynamicUnaryOp(const 
SoftRelu& softrelu) + { + unary_op_type_ = UnaryOpType::SoftRelu; + alpha = softrelu.get_alpha(); + } + + __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu) + { + unary_op_type_ = UnaryOpType::SoftRelu; + alpha = softrelu.get_alpha(); + } + + __host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; } + + __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; } + + __host__ __device__ DynamicUnaryOp(const Power& pow) + { + unary_op_type_ = UnaryOpType::Power; + alpha = pow.get_alpha(); + beta = pow.get_beta(); + gamma = pow.get_gamma(); + } + + __host__ __device__ DynamicUnaryOp(const Power&& pow) + { + unary_op_type_ = UnaryOpType::Power; + alpha = pow.get_alpha(); + beta = pow.get_beta(); + gamma = pow.get_gamma(); + } + + __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu) + { + unary_op_type_ = UnaryOpType::ClippedRelu; + alpha = clippedrelu.get_alpha(); + beta = clippedrelu.get_beta(); + } + + __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu) + { + unary_op_type_ = UnaryOpType::ClippedRelu; + alpha = clippedrelu.get_alpha(); + beta = clippedrelu.get_beta(); + } + + __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu) + { + unary_op_type_ = UnaryOpType::LeakyRelu; + alpha = leakyrelu.get_alpha(); + } + + __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu) + { + unary_op_type_ = UnaryOpType::LeakyRelu; + alpha = leakyrelu.get_alpha(); + } + + __host__ __device__ DynamicUnaryOp(const Elu& elu) + { + unary_op_type_ = UnaryOpType::Elu; + alpha = elu.get_alpha(); + } + + __host__ __device__ DynamicUnaryOp(const Elu&& elu) + { + unary_op_type_ = UnaryOpType::Elu; + alpha = elu.get_alpha(); + } + + __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) + : unary_op_type_(dynamic_op.unary_op_type_), + unary_op_ptr_(dynamic_op.unary_op_ptr_), + alpha(dynamic_op.alpha), + beta(dynamic_op.beta), + gamma(dynamic_op.gamma) + { + } + + __host__ __device__ ~DynamicUnaryOp() + { + switch(unary_op_type_) + { + case(UnaryOpType::Swish): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::Sigmoid): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::PassThrough): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::Logistic): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::TanH): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::Relu): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::SoftRelu): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::UnaryAbs): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::Power): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::ClippedRelu): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::LeakyRelu): delete static_cast(unary_op_ptr_); break; + case(UnaryOpType::Elu): delete static_cast(unary_op_ptr_); break; + + default: break; + } + } + + __device__ void InitUnaryOpPtrOnDevice() + { + switch(unary_op_type_) + { + case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break; + case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break; + case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break; + case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break; + case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break; + case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break; + case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break; + 
case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break; + case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break; + case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break; + case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break; + case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break; + + default: unary_op_ptr_ = nullptr; break; + } + } + + template + __device__ void operator()(Y& y, const X& x) const + { + isSupported(); + unary_op_ptr_->operator()(y, x); + } + + template + __host__ void operator()(Y& y, const X& x) const + { + isSupported(); + switch(unary_op_type_) + { + case(UnaryOpType::Swish): Swish{}.operator()(y, x); break; + case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break; + case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break; + case(UnaryOpType::Logistic): Logistic{}.operator()(y, x); break; + case(UnaryOpType::TanH): TanH{}.operator()(y, x); break; + case(UnaryOpType::Relu): Relu{}.operator()(y, x); break; + case(UnaryOpType::SoftRelu): SoftRelu{}.operator()(y, x); break; + case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break; + case(UnaryOpType::Power): Power{}.operator()(y, x); break; + case(UnaryOpType::ClippedRelu): ClippedRelu{}.operator()(y, x); break; + case(UnaryOpType::LeakyRelu): LeakyRelu{}.operator()(y, x); break; + case(UnaryOpType::Elu): Elu{}.operator()(y, x); break; + default: break; + } + } + + template + __device__ __host__ constexpr void isSupported() const + { + + static_assert(std::is_same::value, "X and Y must be of the same type"); + + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Data type is not supported by this operation!"); + } + + private: + enum class UnaryOpType + { + Swish, + Sigmoid, + PassThrough, + Logistic, + TanH, + Relu, + SoftRelu, + UnaryAbs, + Power, + ClippedRelu, + LeakyRelu, + Elu }; - float beta_ = 1.0f; + public: + UnaryOpType unary_op_type_; + UnaryOpBase* unary_op_ptr_ = nullptr; + float alpha; + float beta; + float gamma; }; +#pragma clang diagnostic pop } // namespace element_wise } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index e71ade77832ca6e6db873e60b33725a20430a65e..20f2cd4a0d6b79d37898572193d5e7d0fa87ee19 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
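// The parameterized activations defined above reduce to simple closed forms. A minimal
// host-side sketch of that math follows, assuming plain <cmath>; the ref_* helpers are
// illustrative names only and are not part of the ck:: element-wise API.
//   Swish:       y = x / (1 + exp(-beta * x))
//   SoftRelu:    y = log(1 + exp(alpha * x)) / alpha
//   Power:       y = pow(alpha + beta * x, gamma)
//   ClippedRelu: y = min(beta, max(alpha, x))
//   LeakyRelu:   y = x >= 0 ? x : alpha * x
//   Elu:         y = x >  0 ? x : alpha * (exp(x) - 1)
//   Logistic:    y = alpha / (1 + alpha * exp(-x))
#include <cmath>

inline float ref_swish(float x, float beta = 1.0f) { return x / (1.0f + std::exp(-beta * x)); }

inline float ref_soft_relu(float x, float alpha = 1.0f)
{
    return std::log(1.0f + std::exp(alpha * x)) / alpha;
}

inline float ref_elu(float x, float alpha = 1.0f)
{
    return x > 0.0f ? x : alpha * (std::exp(x) - 1.0f); // expm1(x) in the device code
}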
#pragma once @@ -27,10 +27,10 @@ struct BlockToCTileMap_M00_N0_M01 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - __host__ __device__ BlockToCTileMap_M00_N0_M01() = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default; - __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01 = 1) + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1) : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01)) { } @@ -54,8 +54,8 @@ struct BlockToCTileMap_M00_N0_M01 } template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const + __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const { if constexpr(DeviceCTileIndexCheck) return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); @@ -63,7 +63,7 @@ struct BlockToCTileMap_M00_N0_M01 return true; } - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel @@ -138,6 +138,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) : M_(M), N_(N), M01_(M01) { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_); + } +#endif } template @@ -259,6 +264,276 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, + index_t N, + index_t M01 = 8) + : M_(M), N_(N), M01_(M01) + { + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + if(M0 == 1) + { + return make_tuple(0, block_1d_id); + } + else if(N0 == 1) + { + return make_tuple(block_1d_id, 0); + } + // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + else + { + const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum); + const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0); + auto group_id_x = block_1d_id % GroupNum; + auto group_id_y = block_1d_id / GroupNum; + auto remap_block_1d_id = + group_id_x <= big_group_num + ? group_id_x * group_size + group_id_y + : group_id_x * group_size + big_group_num - group_id_x + group_id_y; + + index_t idx_N0 = remap_block_1d_id % N0; + index_t idx_M0 = remap_block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? 
M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + /** + * idxN0 + * + * |< mtx N >| + * + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * - |-----------|-----------|-----------|-----|-----|- + * ^ | - - 0 |/----> 2 | | | | + * | | | / | | | | | M_0 MPerBlock + * | M | /| | | | | | + * |-0---|---/-|-----|-----|-----------|-----|-----|- + * | 1 | / | | | blockid | | | + * idxM0 | | | / | V | 5 | | | M_1 MPerBlock + * | - V 1 | - 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | | | | | + * | | | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * Example: + * assume: + * M0 = 5 + * N0 = 4 + * block_1d_id = 5 + * M01 = 2 + * + * idx_N0 = 1 + * idx_M0 = 1 + * M01_adapt = 2 + * idx_M00 = 0 + * idx_M01 = 1 + * idx_N0_M01_local = 5 + * output {1, 2} + */ + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t M01_; +}; + +// columns of row-vectors +// This C-tile map dynamically adjusts N01 when C-tile index is out of range +template +struct BlockToCTileMap_N00_M0_N01Adapt; + +template +struct BlockToCTileMap_N00_M0_N01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8) + : M_(M), N_(N), N01_(N01) + { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_); + } +#endif + } + + template + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t N01 = 8) + : BlockToCTileMap_N00_M0_N01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01) + { + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const 
TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_M0 = block_1d_id % M0; + index_t idx_N0 = block_1d_id / M0; + + const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_; + + index_t idx_N00 = idx_N0 / N01_; + index_t idx_N01 = idx_N0 % N01_; + index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0; + + /** + * idxN0 + * + * |< mtx N >| + * + * |<---N01--->| + * - |-----------|-----------|-----------|-----|-----|- + * ^ | 0 ----------> 1 | | | | + * | | / | | | | M_0 MPerBlock + * | / | | | | + * |------/----------------|-----------|-----|-----|- + * | | | | | | | + * idxM0 | V | | | | | M_1 MPerBlock + * | 2 ----------> 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | blockid | | | | + * | | 5 | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * Example: + * assume: + * N0 = 5 + * M0 = 4 + * block_1d_id = 5 + * N01 = 2 + * + * idx_M0 = 1 + * idx_N0 = 1 + * N01_adapt = 2 + * idx_N00 = 0 + * idx_N01 = 1 + * idx_M0_N01_local = 5 + * output {2, 1} + */ + + return make_tuple(idx_M0_N01_local / N01_adapt, + idx_M0_N01_local % N01_adapt + idx_N00 * N01_); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t N01_; +}; + // 2D slices of column-vectors in 3D space // This C-tile map dynamically adjusts M01 when C-tile index is out of range template @@ -322,7 +597,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt return true; // always valid provided that user gets grid size from CalculateGridSize() } - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } private: index_t M01_; @@ -380,7 +658,7 @@ struct BlockToCTileMap_M00_N00_M01_N01 return true; } - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel @@ -492,7 +770,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 return true; } - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel @@ -594,7 +872,8 @@ struct OffsettedBlockToCTileMap { using underlying_type = UnderlyingBlockToCTileMap; - OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start) + __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t block_start) { block_to_ctile_map_ = block_to_ctile_map; block_start_ = block_start; @@ 
-615,7 +894,7 @@ struct OffsettedBlockToCTileMap } template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); } @@ -626,9 +905,59 @@ struct OffsettedBlockToCTileMap return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); } + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + return block_to_ctile_map_.CalculateGridSize(M, N); + } + UnderlyingBlockToCTileMap block_to_ctile_map_; index_t block_start_; }; +// second version with 2 offsets +template +struct OffsettedBlockToCTileMap2 +{ + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t group_offset, + index_t tile_offset) + : block_to_ctile_map_{block_to_ctile_map}, + group_offset_{group_offset}, + tile_offset_{tile_offset} + { + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_)); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + return block_to_ctile_map_.CalculateGridSize(M, N); + } + + __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; } + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t group_offset_; + index_t tile_offset_; +}; /** * @brief Simple tile mapping which creates 3D grid of block of threads. 
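// The column-major N00_M0_N01Adapt mapping above can be checked on the host with a few
// lines of arithmetic. A minimal sketch, assuming the same index math as its
// CalculateBottomIndex; n00_m0_n01_adapt_map is an illustrative helper, not the library
// type. It reproduces the worked example in the comment: M0 = 4, N0 = 5, N01 = 2,
// block_1d_id = 5 maps to C-tile {2, 1}.
#include <cassert>
#include <utility>

inline std::pair<int, int> n00_m0_n01_adapt_map(int block_1d_id, int M0, int N0, int N01)
{
    block_1d_id %= M0 * N0;              // swallow batch index
    const int idx_M0 = block_1d_id % M0; // M0 is the fast dimension
    const int idx_N0 = block_1d_id / M0;

    // shrink N01 for the tail group so the last column tiles stay in range
    const int N01_adapt = (idx_N0 < N0 - N0 % N01) ? N01 : N0 % N01;

    const int idx_N00          = idx_N0 / N01;
    const int idx_N01          = idx_N0 % N01;
    const int idx_M0_N01_local = idx_M0 + idx_N01 * M0;

    return {idx_M0_N01_local / N01_adapt,
            idx_M0_N01_local % N01_adapt + idx_N00 * N01};
}

// usage: assert(n00_m0_n01_adapt_map(5, 4, 5, 2) == std::make_pair(2, 1));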
@@ -672,7 +1001,7 @@ struct BlockToCTileMap_3DGrid_KSplit } template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } @@ -1080,4 +1409,326 @@ struct BlockToCTileMap_GemmStreamK } }; +template +struct BlockToCTileMap_GemmStreamK_v2 +{ + static constexpr uint32_t min_k_iters_per_sk_block = 2; + static constexpr uint32_t MPerBlock = MPerBlock_; + static constexpr uint32_t NPerBlock = NPerBlock_; + static constexpr uint32_t KPerBlock = KPerBlock_; + static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_; + static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_; + + //-------------------------------------- + // pass to device + mutable uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + MDiv2 n_tiles; + MDiv k_iters_per_tile; + MDiv equiv_tiles_big; // for reduction + MDiv equiv_tiles_little; // for reduction + + // prefer construct on host + __host__ __device__ BlockToCTileMap_GemmStreamK_v2( + uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1) + { + // total output tiles + uint32_t num_tiles = + math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); + k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock)); + + uint32_t dp_tiles, dp_num_blocks, sk_total_iters; + + // default to regular DP GEMM if sk blocks == 0 + if(streamk_sel == 0) + { + sk_num_blocks = 0; + dp_tiles = num_tiles; + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + // 2-tile sk + DP GEMM + else + { + + // check if there's enough work for DP+ stream-k + bool bigEnough = num_tiles > grid_size; + // select between stream-k strategies + uint32_t sk_tiles = 0; + if(streamk_sel == 1) // 1 tile stream-k + { + sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles; + } + else if(streamk_sel == 2) // 2-tile stream-k + { + sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles; + } + else if(streamk_sel == 3) // 3-tile stream-k + { + sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size) + : num_tiles; + } + else if(streamk_sel == 4) // 4-tile stream-k + { + sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size) + : num_tiles; + } + sk_num_blocks = sk_tiles; + // remaining tiles are DP tiles + dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0; + + sk_total_iters = k_iters_per_tile.get() * sk_tiles; + + // k_iters_per_sk_block is the floor of avg each ck block loop over tiles. 
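// Worked example with illustrative numbers: for sk_total_iters = 10 and sk_num_blocks = 4,
// m = 10 / 4 = 2, so b = 10 - 2 * 4 = 2 big blocks run m + 1 = 3 iterations each and the
// remaining l = 2 little blocks run 2 each; 2 * 3 + 2 * 2 = 10, so every k-iteration of the
// stream-k tiles is covered exactly once.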
+ // we need to decide how many iters for each sk block + // let m = k_iters_per_sk_block + // some of the sk block (little) will cover m iters, some (big) will cover m+1 + // we have + // 1) l + b = sk_blocks + // 2) l * m + b * (m + 1) = sk_total_iters + // => (l + b) * m + b = sk_total_iters + // => sk_blocks * m + b = sk_total_iters + // => b = sk_total_iters - m * sk_blocks + // NOTE: big could be zero + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + + dp_num_blocks = dp_tiles; + dp_start_block_idx = sk_num_blocks; + } + + n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); + // using multiple blocks for parallel reduction + reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; + + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); + equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + __host__ __device__ uint32_t get_sk_total_iters() const + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + __host__ __device__ uint32_t get_sk_tiles() const + { + // tiles for sk + uint32_t sk_total_iters = get_sk_total_iters(); + return k_iters_per_tile.div(sk_total_iters); + } + + __host__ __device__ index_t get_grid_dims() const + { + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1); + return reduction_start_block_idx + get_sk_tiles(); + } + else + return reduction_start_block_idx; + } + + __device__ uint32_t get_block_idx() const + { + // TODO: swizzle block index for better locality + return __builtin_amdgcn_readfirstlane(blockIdx.x); + } + + __device__ void + get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = get_sk_total_iters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + __device__ uint32_t get_current_iter_length(uint32_t iter_start, + uint32_t iter_end, + uint32_t total_iter_length) const + { + uint32_t iter_length_mod, iter_length_quo /*unused*/; + k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); + uint32_t current_iter_length = math::min( + iter_length_mod == 0 ? 
(iter_end - iter_start) : iter_length_mod, total_iter_length); + return current_iter_length; + } + + __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); } + + __device__ void + get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const + { + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); + } + + __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const + { + uint32_t m_tile_idx, n_tile_idx; + uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock); + n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx); + + // // swizzle tile + uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock); + + uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m; + + const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem)) + ? tile_swizzle_sub_m + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m; + m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m, + n_tile_idx_with_adapt); + } + + __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const + { + static constexpr uint32_t alignment = 128; + uint32_t acc_buffer_bytes = + MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes; + return (acc_buffer_bytes + alignment - 1) / alignment * alignment; + } + + __host__ __device__ uint32_t get_workspace_size_for_semaphore() const + { + return get_sk_tiles() * sizeof(uint32_t); + } + + __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const + { + return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore(); + } + + __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, + const MDiv& equiv_tiles_) const + { + uint32_t tile_idx_ = tiles_ == 0 ? 
0 : (tiles_ - 1); + uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1; + uint32_t quo_, rem_; + equiv_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_equiv_tiles_ + rem_; + } + + __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + __host__ __device__ uint32_t get_total_acc_buffers() const + { + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = + get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big); + uint32_t total_intersec_little = + get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const + { + // TODO: from big to little + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little); + return get_total_acc_buffers() - (touched_sk_blocks + current_intersec); + } + } + + __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little); + return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec); + } + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index b25f136a3713a412fbcb166f3df304a172ee42f7..206ea00b9d0a94dcd5f12c02948d8f84f9223fb5 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -522,6 +522,7 @@ struct 
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp index bd1e0585fc96ad29ee9e5336993f2ba6570fb31d..2d9197a7f443f1b9af7796a826a0a326d8d68de5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -258,7 +258,7 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock if(thread_k_cluster_id == 0) { - if(block_group_size == 0 && !float_equal_zero{}(beta_values[iR])) + if(!float_equal_zero{}(beta_values[iR])) { StaticBuffer priorDstValueBuf; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp index 203be3c42d9d23c935d10f34ceb44245f7875ad0..774df1f99348b4fabddee23141a4d6db0a5b1363 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -244,7 +244,7 @@ struct GridwiseReduction_mk_to_m_multiblock if(thread_k_cluster_id == 0) { - if(block_group_size == 0 && !float_equal_zero{}(beta)) + if(!float_equal_zero{}(beta)) { StaticBuffer priorDstValueBuf; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e3c50ef06caae8c2cc9467cb72c8f6a86d92ba83 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
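// Reference semantics for the threadwise multi-D reduction introduced below (host-side
// sketch only; reference_reduce_multi_d and its lambda parameters are illustrative and do
// not mirror the kernel's argument types). Row m of the M x K input is reduced along K
// after applying the input element-wise op, and the accumulator is then combined with the
// per-row D values through the output element-wise op.
#include <cstddef>
#include <functional>
#include <vector>

inline std::vector<float> reference_reduce_multi_d(
    const std::vector<std::vector<float>>& in, // M x K
    const std::vector<std::vector<float>>& ds, // NumDTensor x M
    float identity,
    const std::function<float(float, float)>& reduce_op, // e.g. sum or max
    const std::function<float(float)>& in_op,
    const std::function<float(float, const std::vector<float>&)>& out_op)
{
    std::vector<float> out(in.size());
    for(std::size_t m = 0; m < in.size(); ++m)
    {
        float acc = identity;
        for(float v : in[m])
            acc = reduce_op(acc, in_op(v)); // pre-op, then reduce along K

        std::vector<float> d_m;
        for(const auto& d : ds)
            d_m.push_back(d[m]); // gather the D values belonging to row m

        out[m] = out_op(acc, d_m); // post-op combines the accumulator with the Ds
    }
    return out;
}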
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/tuple_helper.hpp" + +namespace ck { + +template +__global__ void +kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k, + const DsGridDesc_M ds_grid_desc_m, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const OutElementwiseOperation out_elementwise_op, + const InDataType* const __restrict__ p_in_value_global, + const DsGridPointer p_ds_value_global, + OutDataType* const __restrict__ p_out_value_global) +{ + GridwiseReduction::Run(in_grid_desc_m_k, + ds_grid_desc_m, + out_grid_desc_m, + in_elementwise_op, + out_elementwise_op, + p_in_value_global, + p_ds_value_global, + p_out_value_global); +} + +template +struct GridwiseReduction_mk_to_m_threadwise_multi_d +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const DsGridDesc_M& ds_grid_desc_m, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const OutElementwiseOperation& out_elementwise_op, + const InDataType* const __restrict__ p_in_value_global, + const DsGridPointer p_ds_grid, + OutDataType* const __restrict__ p_out_value_global) + { + using ThreadwiseReduce = ThreadwiseReduction; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto dst_global_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + 
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t reducedLength = 0; + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + auto ds_thread_buf = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto ds_global_buf = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_ds_grid[I], ds_grid_desc_m[I].GetElementSpaceSize()); + }, + Number{}); + + auto ds_global_load = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + InSrcVectorDim, // SrcVectorDim + DsVectorSize{}[I], + 1, // SrcScalarStrideInVector + true>{ + ds_grid_desc_m[I], make_multi_index(thread_global_1d_id * MThreadSliceSize)}; + }, + Number{}); + + static_for<0, NumDTensor, 1>{}([&](auto I) { + ds_global_load(I).Run(ds_grid_desc_m[I], + ds_global_buf[I], + reduced_data_desc, + make_tuple(I0), + ds_thread_buf(I)); + }); + + StaticBuffer out_value_buf; + + // if constexpr(NumDTensor > 0) + { + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(accu_value_buf[I]), + generate_tie( + [&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; }, + Number{})); + + unpack2(out_elementwise_op, tie(out_value_buf(I)), c_ds_buf_refs); + }); + } + + auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThrough{}); + + threadwise_dst_store.Run( + reduced_data_desc, make_tuple(I0), out_value_buf, out_grid_desc_m, dst_global_buf); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index c2f47bd44435bff2ea75cbbee625895cb0c1bef4..9469fa7bc7b27351f393caed43a9bc12a2b8780d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -628,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle Gemm1KPack, false, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index d2920570e4f2cefc883aa9bf4bf1d8b845146670..42f7c2a33fbd77ec58ed16889860924b1f8f8dbb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -116,7 +116,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; // ck::Tuple static constexpr auto MakeD0sGridPointer() @@ -880,7 +880,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle Gemm1KPack, false, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{} + Gemm1KPack * XdlopsGemm{} .K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 18cfeebcf30d5ea8e44c9a5a5ff1cf564fef01ba..bc76d4cc4fb9e905cdb1547273866b66575ff334 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -794,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle Gemm1KPack, true, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1754e07e6a2b98a2de3b0056cfda509c76eea0c2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -0,0 +1,1604 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
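// Reference semantics of the fused batched GEMM + softmax + GEMM kernel declared below
// (host-side sketch for a single batch; reference_gemm_softmax_gemm is an illustrative
// helper and ignores batching and the kernel's element-wise operators). Gemm0 forms
// Acc[M x L] = A[M x K] * B0[K x L], a row-wise softmax over L is applied to Acc, and
// Gemm1 forms C[M x N] = softmax(Acc) * B1[L x N].
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

inline Matrix reference_gemm_softmax_gemm(const Matrix& A, const Matrix& B0, const Matrix& B1)
{
    const std::size_t M = A.size(), K = A[0].size(), L = B0[0].size(), N = B1[0].size();
    Matrix C(M, std::vector<float>(N, 0.f));

    for(std::size_t m = 0; m < M; ++m)
    {
        // Gemm0: one row of Acc
        std::vector<float> acc(L, 0.f);
        for(std::size_t l = 0; l < L; ++l)
            for(std::size_t k = 0; k < K; ++k)
                acc[l] += A[m][k] * B0[k][l];

        // row-wise softmax over L, max-subtracted for numerical stability
        float row_max = *std::max_element(acc.begin(), acc.end());
        float row_sum = 0.f;
        for(float& v : acc)
        {
            v = std::exp(v - row_max);
            row_sum += v;
        }
        for(float& v : acc)
            v /= row_sum;

        // Gemm1: one row of C
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t l = 0; l < L; ++l)
                C[m][n] += acc[l] * B1[l][n];
    }
    return C;
}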
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" + +namespace ck { + +// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L] +// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N] +template +struct GridwiseBatchedGemmSoftmaxGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto BK1 = Number{}; + + static constexpr auto L0PerBlock = LTilePerBlock / L1Value; + static constexpr auto AL0 = Number{}; + static constexpr auto AL1 = Number{}; + static constexpr auto BL0 = Number{}; + static constexpr auto BL1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + static constexpr auto WmmaL = 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / AK1; + constexpr auto max_lds_align = AK1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / AK1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + AK1), + make_tuple(Number{} * Number{} * AK1, + Number{} * AK1, + Number{} * AK1, + AK1, + AK1, + AK1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeB0BlockDescriptor() + { + constexpr auto b0_block_desc = [&]() { + if constexpr(B0EnableLds) + { + // K0->L->BK1 Per Block + constexpr auto K0PerBlock = KPerBlock / BK1; + constexpr auto max_lds_align = BK1; + + if constexpr(B0BlockLdsExtraL) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), 
max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / BK1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + BK1), + make_tuple(Number{} * Number{} * BK1, + Number{} * BK1, + Number{} * BK1, + BK1, + BK1, + BK1, + I1)); + } + }(); + + return b0_block_desc; + } + + __host__ __device__ static constexpr auto MakeB1BlockDescriptor() + { + constexpr auto b1_block_desc = [&]() { + if constexpr(B1EnableLds) + { + // L0->N->BL1 Per Block + constexpr auto max_lds_align = BL1; + + if constexpr(B1BlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BL1), + make_tuple(Number{} * BL1, BL1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BL1), max_lds_align); + } + } + else + { + constexpr auto LWmmaPerblock = LPerBlock / WmmaL; + constexpr auto L0PerWmma = WmmaL / 2 / BL1; + // LWmma->NRepeat->MWave->L0PerWmma->LRow->MPerWmma->L1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + BL1), + make_tuple(Number{} * Number{} * BL1, + Number{} * BL1, + Number{} * BL1, + BL1, + BL1, + BL1, + I1)); + } + }(); + + return b1_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / AK1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep() + { + constexpr auto b0_block_copy_step = [&]() { + if constexpr(B0EnableLds) + { + constexpr auto K0PerBlock = KPerBlock / BK1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b0_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep() + { + constexpr auto b1_block_copy_step = [&]() { + if constexpr(B1EnableLds) + { + return make_multi_index(L0PerBlock, 0, 0); + } + else + { + constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL; + + return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b1_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr 
auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&) + { + + constexpr auto b0_wave_desc = [&]() { + if constexpr(B0EnableLds) + { + // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 + constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else + constexpr auto B_KRow = I1; +#endif + return transform_tensor_descriptor( + B0BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = B0BlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = B0BlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b0_wave_desc; + } + + template + __host__ __device__ static constexpr auto + MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&) + { + constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0); + constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2); + constexpr auto A_LRow = I1; + return transform_tensor_descriptor( + A1BlockDesc_AL0_M_AL1{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_LRow)), + make_unmerge_transform(make_tuple(Number{}, I1, I1)), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + template + __host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&) + { + + constexpr auto b1_wave_desc = [&]() { + if constexpr(B1EnableLds) + { + // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 + constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_LRow = I2; +#else + constexpr auto B_LRow = I1; +#endif + return transform_tensor_descriptor( + B1BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0); + constexpr auto L0PerWmma = B1BlockDesc_{}.GetLength(I3); + constexpr auto B_LRow = B1BlockDesc_{}.GetLength(I4); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b1_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + 
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + const index_t gemm0_bytes_end = + (SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) + + SharedMemTrait::b0_block_space_size_aligned * sizeof(B0DataType)); + + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + + SharedMemTrait::b1_block_space_size_aligned * sizeof(B1DataType)); + + const index_t softmax_bytes_end = + SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned * sizeof(Acc0DataType); + + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const B0GridDesc& b0_grid_desc, + const B1GridDesc& b1_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (LPerBlock % (LPerWmma * LRepeat)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetB0ProblemsizeLK = [&]() { + if constexpr(B0EnableLds) + { + return make_tuple(b0_grid_desc.GetLength(I1), + b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * + b0_grid_desc.GetLength(I5), + b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) * + b0_grid_desc.GetLength(I4) * b0_grid_desc.GetLength(I6)); + } + }; + + const auto GetB1ProblemsizeNL = [&]() { + if constexpr(B1EnableLds) + { + return make_tuple(b1_grid_desc.GetLength(I1), + b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) * + b1_grid_desc.GetLength(I5), + b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) * + b1_grid_desc.GetLength(I4) * b1_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto L = GetB0ProblemsizeLK()(I0); + const auto K = GetAProblemsizeMK()[I1]; + const auto N = GetB1ProblemsizeNL()(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + { + printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + return false; + } + + if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock 
== 0)) + { + printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + printf("GridwiseOp: outer loop unsupport\n"); + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(LPerBlock % LTilePerBlock == 0)) + { + printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + return false; + } + + const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + printf("GridwiseOp: inner loop unsupport\n"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = math::integer_divide_ceil(K, KPerBlock); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1); + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b0_block_space_size_aligned = + B0EnableLds ? math::integer_least_multiple( + MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align) + : 0; + static constexpr auto b1_block_space_size_aligned = + B1EnableLds ? 
math::integer_least_multiple( + MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b0_block_space_offset = a_block_space_size_aligned; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for reduction + // Feature to add, IntraThread Reduction + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const B0GridDesc& b0_grid_desc, + const B1GridDesc& b1_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const B0ElementwiseOperation& b0_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const C0MatrixMask& c0_matrix_mask, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b0_grid_buf = make_dynamic_buffer( + p_b0_grid, b0_grid_desc.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// set up Gemm0 +/*******************************************************************************/ + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b0_block_desc = MakeB0BlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto AK0PerBlock = KPerBlock/ AK1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< 
ThisThreadBlock, +/* typename SrcElementwiseOperation, */ AElementwiseOperation, +/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/AK1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b0_block_trait = [&](){ + if constexpr(B0EnableLds) + { + auto b0_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b0_block_space_offset, + SharedMemTrait::b0_block_space_size_aligned); + + auto b0_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0DataType, + B0DataType, + decltype(b0_grid_desc), + decltype(b0_block_desc), + B0BlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + B0BlockTransferSrcVectorDim, + 2, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + 1, + 1, + B0ThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b0_grid_desc, + make_multi_index(0, 0, 0), + b0_element_op, + b0_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b0_block_buf, b0_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/BK1Value; + auto b0_block_buf = make_static_buffer( + 
b0_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b0_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + B0BlockTransferSrcScalarPerVector, + B0ThreadTransferSrcResetCoordinateAfterRun, + true>( + b0_grid_desc, + make_multi_index(0, + 0/(LWaves * LPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b0_block_buf, b0_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b0_block_buf = b0_block_trait()[I0]; + auto b0_blockwise_copy = b0_block_trait()[I1]; + +/*******************************************************************************/ + // Gemm0 + constexpr auto KPack = math::integer_least_multiple(math::integer_least_multiple(AK1Value,BK1Value), WmmaK); + + auto blockwise_gemm0 = BlockwiseGemmWMMA< + BlockSize, + ADataType, + B0DataType, + Acc0DataType, + decltype(MakeAWaveDescriptor(a_block_desc)), + decltype(MakeB0WaveDescriptor(b0_block_desc)), + MPerBlock, + LPerBlock, + KPerBlock, + MPerWmma, + LPerWmma, + MRepeat, + LRepeat, + KPack, + AEnableLds, + B0EnableLds, + true>{}; // C' = B' x A' + + + // Prepare Register for A*B0 matrix + auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer(); + + constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto mrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto mwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto lrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto lwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor( + acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)), + make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)), + make_pass_through_transform(laccvgprs)), + make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b0_block_slice_copy_step = MakeB0BlockSliceCopyStep(); + + const auto a_block_reset_copy_step = [&](){ + if constexpr(AEnableLds){ + return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0); + } + else{ + return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0, 
0); + } + }(); + + const auto b0_block_reset_copy_step = [&](){ + if constexpr(B0EnableLds){ + return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0); + } + else{ + return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0, 0); + } + }(); + + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); +/*******************************************************************************/ +// softmax +/*******************************************************************************/ + auto workspace_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::reduction_space_offset, + SharedMemTrait::reduction_space_size_aligned); + // get acc0 7D thread cluster + constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths() / + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + constexpr auto t_mrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I0); + constexpr auto t_mwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I1); + constexpr auto t_mthreadpersubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I2); + constexpr auto t_lrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I3); + constexpr auto t_lwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4); + constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5); + constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6); + // get acc0 thread map + constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)), + make_pass_through_transform(I1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + constexpr auto threadid_to_m0_l_m1_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform( + make_tuple(t_mrepeat * t_mwave, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs, t_mthreadpersubgroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + const auto threadid_to_l_n_thread_cluster_adaptor = + chain_tensor_adaptors(m0_l_m1_to_m_l_adaptor, threadid_to_m0_l_m1_adaptor); + + // get acc0 2D thread cluster & 2D thread slice + constexpr auto thread_cluster_desc_m_l = make_naive_tensor_descriptor_packed( + make_tuple(t_mrepeat * t_mwave * t_mthreadpersubgroup, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs)); + + constexpr auto thread_slice_desc_m_l = make_naive_tensor_descriptor_packed( + make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs)); + + auto blockwise_softmax = BlockwiseSoftmax{}; + + // Initialize running sum and max of exponentiating row vectors + using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; + SoftmaxBuf running_sum, 
running_sum_new, running_max, running_max_new; + running_sum = 0; + running_sum_new = 0; + running_max = NumericLimits::Lowest(); + running_max_new = NumericLimits::Lowest(); +/*******************************************************************************/ +// set up Gemm1 +/*******************************************************************************/ + // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm + // A1 matrix in VGPR + constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple( + Number{}, + Number{}, + Number{}); + + constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0]; + constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1]; + constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2]; + + constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor( + make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1), + make_tuple(A1ThreadSliceMPerBlock * A1ThreadSliceL1, A1ThreadSliceL1, I1)); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + Acc0DataType, + ADataType, + decltype(acc0_thread_desc_l0perblock_mperblock_l1), + decltype(a1_thread_desc_l0perblock_mperblock_l1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2>, + 2, + laccvgprs>{tensor_operation::element_wise::PassThrough{}}; + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize()); + + constexpr auto b1_block_desc = MakeB1BlockDescriptor(); + + auto b1_block_trait = [&](){ + if constexpr(B1EnableLds) + { + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + SharedMemTrait::b1_block_space_size_aligned); + + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +/* typename SrcElementwiseOperation, */ B1ElementwiseOperation, +/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1, +/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ B1DataType, +/* typename DstData, */ B1DataType, +/* typename SrcDesc, */ decltype(b1_grid_desc), +/* typename DstDesc, */ decltype(b1_block_desc), +/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, +/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ B1BlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ B1BlockTransferDstScalarPerVector_L1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b1_block_buf, b1_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto LWmmaPerBlock = 
LTilePerBlock / WmmaL; + constexpr auto L0PerWmma = WmmaL/2/L1Value; + auto b1_block_buf = make_static_buffer( + b1_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b1_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + B1BlockTransferSrcScalarPerVector, + B1ThreadTransferSrcResetCoordinateAfterRun, + true>( + b1_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b1_block_buf, b1_blockwise_copy); + } + }; + + auto b1_block_buf = b1_block_trait()[I0]; + auto b1_blockwise_copy = b1_block_trait()[I1]; + + constexpr auto b1_block_slice_copy_step = MakeB1BlockSliceCopyStep(); + + auto blockwise_gemm1 = + BlockwiseGemmWMMA{make_tuple(0, 0, 0, 0, 0, 0)}; + + auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); + + const auto L = [&](){ + if constexpr(B0EnableLds){ + return b0_grid_desc.GetLength(I1); + } + else{ + return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I5); + } + }(); + + const index_t num_gemm1_l_block_outer_loop = L / LPerBlock; + constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock; + + // Initialize C + StaticBuffer c_thread_buf; + c_thread_buf.Clear(); + +/*******************************************************************************/ + // + // Kernel Main Stage + // + // Flash Attention + // Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." arXiv preprint arXiv:2205.14135 (2022). + index_t gemm1_l_block_outer_index = 0; + // Outer loop, along GEMM_L + // Inner loop, along GEMM_K + do{ + auto l_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(gemm1_l_block_outer_index * LPerBlock); + if(c0_matrix_mask.IsTileSkippable( + m_block_data_idx_on_grid, l_block_data_idx_on_grid, MPerBlock, LPerBlock)) + { + continue; + } + // gemm0 start, A-B swaped + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b0_grid_desc, + b0_block_desc, + b0_blockwise_copy, + b0_grid_buf, + b0_block_buf, + b0_block_slice_copy_step, + blockwise_gemm0, + acc0_thread_buf, + KBlockMainLoop); + // do MNK padding or upper triangular masking + if constexpr(MaskOutUpperTriangle || PadN) + { + // 7d thread_desc in thread scope + constexpr auto c_thread_lengths = + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + + // 7d block_desc in block scope + constexpr auto c_block_lengths = + blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + + constexpr auto MREPEAT = c_block_lengths[I0]; + constexpr auto MWAVE = c_block_lengths[I1]; + constexpr auto MTHREADSubGroup = c_block_lengths[I2]; + constexpr auto LREPEAT = c_block_lengths[I3]; + constexpr auto LWAVE = c_block_lengths[I4]; + constexpr auto LSUBGROUP = c_block_lengths[I5]; + constexpr auto LACCVGPRS = c_block_lengths[I6]; + + // works like multi-dimension static_for (static_ford), but provides both the linear + // index as well as n-d index + using Acc0TileIterator = SpaceFillingCurve< + decltype(c_thread_lengths), + typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, + typename 
uniform_sequence_gen::type, + false>; // SnakeCurved + + auto acc0_thread_origin = blockwise_gemm0.CalculateCThreadOriginDataIndex7D( + Number<0>{}, Number<0>{}); + + constexpr auto block_idx_to_m_l_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MREPEAT, MWAVE, MTHREADSubGroup)), + make_unmerge_transform(make_tuple(LREPEAT, LWAVE, LSUBGROUP, LACCVGPRS))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{})); + + static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { + auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; + auto m_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; + auto l_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; + auto m_global = m_local + m_block_data_idx_on_grid; + auto l_global = l_local + l_block_data_idx_on_grid; + if(c0_matrix_mask.IsMaskedElement(m_global, l_global)) + { + acc0_thread_buf(i) = -ck::NumericLimits::Infinity(); + } + else + { + acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); + } + }); + } + else + { static_for<0, acc0_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); }); + } + + block_sync_lds(); + // Tiled softmax start + // softmax + SoftmaxBuf& max = blockwise_softmax.max_value_buf; + SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; + + blockwise_softmax.Run(acc0_thread_buf, workspace_buf); + + // TODO: may convert to log domain + running_max_new = mathext::max(max, running_max); + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + + mathext::exp(max - running_max_new) * sum; + + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. 
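// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): the running_max / running_sum update
// above, together with the c_new rescaling a few lines further down, is the
// online (streaming) softmax of FlashAttention: the L dimension is consumed
// tile by tile, and the already-accumulated output is rescaled whenever the
// running row maximum changes. A scalar sketch of the same recurrence, under
// the reading that acc1 holds the tile-local, unnormalized exp(s - tile_max)*V
// partial product; the names here are hypothetical, not CK types.
#include <algorithm>
#include <cmath>

struct OnlineSoftmaxState
{
    float running_max = -INFINITY; // running row maximum
    float running_sum = 0.f;       // running sum of rescaled exponentials
    float out_acc     = 0.f;       // normalized output accumulated so far

    // Fold in one L-tile: its row max, its sum of exp(s - tile_max), and its
    // unnormalized P*V partial product computed with that tile-local max.
    void fold(float tile_max, float tile_sum, float tile_pv)
    {
        const float new_max    = std::max(running_max, tile_max);
        const float old_scale  = std::exp(running_max - new_max);
        const float tile_scale = std::exp(tile_max - new_max);

        const float new_sum = old_scale * running_sum + tile_scale * tile_sum;

        // Same shape as the kernel's c_new update: un-normalize the old output
        // with running_sum, rescale both contributions to the new maximum,
        // then renormalize with the updated running sum.
        out_acc = (running_sum * old_scale * out_acc + tile_scale * tile_pv) / new_sum;

        running_max = new_max;
        running_sum = new_sum;
    }
};
// ---------------------------------------------------------------------------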
+ + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for reduction LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + + // main body + if constexpr(num_gemm1_l_block_inner_loop > 1) + { + static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) { + // Data cast from Acc0DataType to ADataType happen here + a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple(Number{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple( + Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup, + c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + Acc1DataType acc1 = acc1_thread_buf[I]; // P*V + Acc1DataType c = c_thread_buf[I]; // O + Acc1DataType c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; + + c_thread_buf(I) = c_new; // O_new + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, + a_block_reset_copy_step); // rewind K + b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc, + b0_block_reset_copy_step); // rewind K and 
step N + + // update before next j iteration + running_max = running_max_new; + running_sum = running_sum_new; + + block_sync_lds(); // wait for gemm1 LDS read + }while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp = + blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1); + constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4); + constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5); + constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MThreadPerSubGroup // MThreadPerSubGroup = MPerWmma + )), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NSubGroup, + NAccVgprs))), // NSubGroup * NAccVgprs = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2>{}, Sequence<>{}, Sequence<3, 4, 5, 6>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto 
m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 8, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if 
constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index f4b82badf14584d5e59fda6dc6751a83aa67b32f..afb2ad2e760396c200930254f871ff13e032a91a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -649,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle Gemm1KPack, true, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp deleted file mode 100644 index d686c14b350953a48bf132694251ac677fb47c4f..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp +++ /dev/null @@ -1,195 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/utility/data_type.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -namespace ck { - -template -__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, - const OutGrid1dDescTuple out_grid_1d_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const ElementwiseOperation elementwise_op) -{ - GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple, - out_grid_1d_desc_tuple, - p_in_global_tuple, - p_out_global_tuple, - elementwise_op); -} - -template -struct GridwiseElementwise_1D -{ - static constexpr index_t NumInput = InDataTypePointerTuple::Size(); - static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); - - static_assert(NumInput == InScalarPerVectorSeq::Size() && - NumOutput == OutScalarPerVectorSeq::Size() && - NumInput == InGrid1dDescTuple::Size() && - NumOutput == OutGrid1dDescTuple::Size(), - "Tuple size is inconsistent with the number of in/out!"); - - static constexpr auto I0 = Number<0>{}; - - static constexpr auto thread_buffer_desc_m = - make_naive_tensor_descriptor_packed(make_tuple(Number{})); - - using PassThroughOp = tensor_operation::element_wise::PassThrough; - - __device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple, - const OutGrid1dDescTuple out_grid_1d_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const ElementwiseOperation elementwise_op) - { - const index_t thread_global_id = get_thread_global_1d_id(); - - auto in_thread_buf_tuple = generate_tuple( - [&](auto I) { - using DataTypePointer = remove_cvref_t; - using DataType = remove_cv_t>; - - return StaticBuffer{}; - }, - Number{}); - - auto out_thread_buf_tuple = generate_tuple( - [&](auto I) { - using 
DataTypePointer = remove_cvref_t; - using DataType = remove_pointer_t; - - return StaticBuffer{}; - }, - Number{}); - - auto in_global_buf_tuple = generate_tuple( - [&](auto I) { - static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); - - return make_dynamic_buffer( - p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize()); - }, - Number{}); - - auto out_global_buf_tuple = generate_tuple( - [&](auto I) { - static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); - - return make_dynamic_buffer( - p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize()); - }, - Number{}); - - const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread); - - const index_t blockSize = get_block_size(); - const index_t blockPerGrid = get_grid_size(); - const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0); - const index_t loop_step = blockPerGrid * blockSize * MPerThread; - const auto loop_step_index = make_multi_index(loop_step); - - auto in_global_load_tuple = generate_tuple( - [&](auto I) { - using DataTypePointer = remove_cvref_t; - using DataType = remove_cv_t>; - - return ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0>, // DimAccessOrder - 0, // SrcVectorDim - InScalarPerVectorSeq::At( - I), // ScalarPerVector - 1, // SrcScalarStrideInVector - false>{in_grid_1d_desc_tuple[I], - thread_global_offset}; - }, - Number{}); - - auto out_global_store_tuple = generate_tuple( - [&](auto I) { - using DataTypePointer = remove_cvref_t; - using DataType = remove_pointer_t; - - return ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths - Sequence<0>, // DimAccessOrder - 0, // SrcVectorDim - OutScalarPerVectorSeq::At(I), - InMemoryDataOperationEnum::Set, - 1, - false>( - out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{}); - }, - Number{}); - - index_t num_iter = M / (loop_step); - do - { - static_for<0, NumInput, 1>{}([&](auto I) { - in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I], - in_global_buf_tuple[I], - thread_buffer_desc_m, - make_tuple(I0), - in_thread_buf_tuple(I)); - - in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I], - loop_step_index); - }); - - static_for<0, MPerThread, 1>{}([&](auto iM) { - // get reference to in data - const auto in_data_refs = generate_tie( - // return type should be lvalue - [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); }, - Number{}); - - // get reference to dst data - auto out_data_refs = generate_tie( - // return type should be lvalue - [&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); }, - Number{}); - - unpack2(elementwise_op, out_data_refs, in_data_refs); - }); - - static_for<0, NumOutput, 1>{}([&](auto I) { - out_global_store_tuple(I).Run(thread_buffer_desc_m, - make_tuple(I0), - out_thread_buf_tuple[I], - out_grid_1d_desc_tuple[I], - out_global_buf_tuple(I)); - - out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I], - loop_step_index); - }); - } while(--num_iter); - } -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..13e9f7bd5ec822ecefa84049ce0b04eb7e630411 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
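// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): both the removed GridwiseElementwise_1D
// above and the scaled replacement introduced below walk the flattened M
// dimension with a grid-stride loop: each thread owns MPerThread contiguous
// elements and advances by grid_size * block_size * MPerThread per iteration
// (the CK loop precomputes num_iter = M / loop_step instead of bounds-checking).
// A minimal HIP sketch of that access pattern with hypothetical names; the CK
// version additionally vectorizes loads/stores via the ScalarPerVector
// sequences and applies the element-wise operator tuples.
#include <hip/hip_runtime.h>

template <int MPerThread>
__global__ void elementwise_1d_sketch(const float* in, float* out, int M)
{
    const int thread_global_id = blockIdx.x * blockDim.x + threadIdx.x;
    const int loop_step        = gridDim.x * blockDim.x * MPerThread;

    for(int base = thread_global_id * MPerThread; base < M; base += loop_step)
    {
        for(int i = 0; i < MPerThread; ++i)
            out[base + i] = in[base + i]; // the element-wise op would be applied here
    }
}
// ---------------------------------------------------------------------------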
+ +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, + const OutGrid1dDescTuple out_grid_1d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const UnaryOperation unary_op, + const Scale scale_op) +{ + GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple, + out_grid_1d_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + elementwise_op, + unary_op, + scale_op); +} + +template +struct GridwiseElementwise_1D +{ + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size() && + NumInput == InGrid1dDescTuple::Size() && + NumOutput == OutGrid1dDescTuple::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static constexpr auto I0 = Number<0>{}; + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + __device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple, + const OutGrid1dDescTuple out_grid_1d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const UnaryOperation unary_op, + const Scale scale_op) + { + const index_t thread_global_id = get_thread_global_1d_id(); + + auto in_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto out_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return StaticBuffer{}; + }, + Number{}); + + auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); + + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto out_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); + + return make_dynamic_buffer( + p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread); + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * MPerThread; + const auto loop_step_index = make_multi_index(loop_step); + + auto in_global_load_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InScalarPerVectorSeq::At( + I), // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc_tuple[I], + thread_global_offset}; + }, + Number{}); + + 
auto out_global_store_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + OutScalarPerVectorSeq::At(I), + InMemoryDataOperationEnum::Set, + 1, + false>( + out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{}); + }, + Number{}); + + index_t num_iter = M / (loop_step); + do + { + static_for<0, NumInput, 1>{}([&](auto I) { + in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I], + in_global_buf_tuple[I], + thread_buffer_desc_m, + make_tuple(I0), + in_thread_buf_tuple(I)); + + in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I], + loop_step_index); + }); + + static_for<0, MPerThread, 1>{}([&](auto iM) { + // get reference to in data + auto uop_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + // get reference to dst data + auto out_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(unary_op, uop_data_refs, uop_data_refs); + + auto sop_in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + auto sop_out_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(scale_op, sop_out_data_refs, sop_in_data_refs); + + const auto in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(elementwise_op, out_data_refs, in_data_refs); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + out_global_store_tuple(I).Run(thread_buffer_desc_m, + make_tuple(I0), + out_thread_buf_tuple[I], + out_grid_1d_desc_tuple[I], + out_global_buf_tuple(I)); + + out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I], + loop_step_index); + }); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index bf0e8c186c75cef05b08c7ef9e1bf16c016830a3..2bbe4c8ce8544bdcb74bafda59cc03da2f80352e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -1,229 +1,282 @@ // SPDX-License-Identifier: MIT -// // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -// +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
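+//
+// This revision replaces the former per-thread GridwiseElementwise_2D loop with
+// a block-tiled GridwiseElementwise: a Block2TileMap hands each workgroup an
+// M0PerBlock x M1PerBlock tile that is moved in a single
+// ThreadGroupTensorSliceTransfer_v4r2 pass, while kernel_elementwise_dual and
+// kernel_batched_elementwise reuse the same functor for two independent
+// descriptor sets and for batched tensors, respectively.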
+ #pragma once #include "ck/tensor_description/cluster_descriptor.hpp" #include "ck/utility/data_type.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/common_header.hpp" namespace ck { -template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op) +{ + GridwiseElementwiseFunctor::Run(in_grid_desc_tuple, + out_grid_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + block_2_tile_map, + elementwise_op); +} + +template -__global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tuple, - const OutGrid2dDescTuple out_grid_2d_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const ElementwiseOperation elementwise_op, - const index_t num_threads_m, - const index_t num_threads_n) +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InDataTypePointerTuple p_in_global_tuple_a, + const InDataTypePointerTuple p_in_global_tuple_b, + const OutDataTypePointerTuple p_out_global_tuple_a, + const OutDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size) { - GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple, - out_grid_2d_desc_tuple, - p_in_global_tuple, - p_out_global_tuple, - elementwise_op, - num_threads_m, - num_threads_n); + if(get_block_1d_id() < a_grid_size) + { + GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a, + out_grid_desc_tuple_a, + p_in_global_tuple_a, + p_out_global_tuple_a, + block_2_tile_map_a, + elementwise_op, + get_block_1d_id()); + } + else + { + GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b, + out_grid_desc_tuple_b, + p_in_global_tuple_b, + p_out_global_tuple_b, + block_2_tile_map_b, + elementwise_op, + get_block_1d_id() - a_grid_size); + } } -template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op, + const index_t batch_count, + const std::array input_batch_strides, + const std::array output_batch_strides) +{ + static_assert(InGridDescTuple::Size() == NumInputs && + InDataTypePointerTuple::Size() == NumInputs); + 
static_assert(OutGridDescTuple::Size() == NumOutputs && + OutDataTypePointerTuple::Size() == NumOutputs); + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + InDataTypePointerTuple p_in_global_with_offset_tuple; + OutDataTypePointerTuple p_out_global_with_offset_tuple; + + static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_in_global_with_offset_tuple(i) = p_in_global_tuple.At(i) + input_batch_strides[i] * g_idx; + }); + + static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_out_global_with_offset_tuple(i) = + p_out_global_tuple.At(i) + output_batch_strides[i] * g_idx; + }); + + GridwiseElementwiseFunctor::Run(in_grid_desc_tuple, + out_grid_desc_tuple, + p_in_global_with_offset_tuple, + p_out_global_with_offset_tuple, + block_2_tile_map, + elementwise_op); +} + +template -struct GridwiseElementwise_2D + typename OutScalarPerVectorSeq, + index_t SrcVectorDim, + index_t DstVectorDim> +struct GridwiseElementwise { static constexpr index_t NumInput = InDataTypePointerTuple::Size(); static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); static_assert(NumInput == InScalarPerVectorSeq::Size() && NumOutput == OutScalarPerVectorSeq::Size() && - NumInput == InGrid2dDescTuple::Size() && - NumOutput == OutGrid2dDescTuple::Size(), + NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(), "Tuple size is inconsistent with the number of in/out!"); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - static constexpr auto thread_buffer_desc_mn = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + static_assert((SrcVectorDim == I0 || SrcVectorDim == I1) && + (DstVectorDim == I0 || DstVectorDim == I1), + "Vector dim must be equal to 0 or 1."); using PassThroughOp = tensor_operation::element_wise::PassThrough; - __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple, - const OutGrid2dDescTuple out_grid_2d_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const ElementwiseOperation elementwise_op, - const index_t num_threads_m, - const index_t num_threads_n) + __device__ static void Run(const InGridDescTuple& in_grid_desc_tuple, + const OutGridDescTuple& out_grid_desc_tuple, + const InDataTypePointerTuple& p_in_global_tuple, + const OutDataTypePointerTuple& p_out_global_tuple, + const Block2TileMap& block_2_tile_map, + const ElementwiseOperation& elementwise_op, + const index_t block_id = get_block_1d_id()) { - auto in_thread_buf_tuple = generate_tuple( + + constexpr auto src_datas = generate_tuple( [&](auto I) { using DataTypePointer = remove_cvref_t; using DataType = remove_cv_t>; - return StaticBuffer{}; + return DataType{}; }, Number{}); - auto out_thread_buf_tuple = generate_tuple( + constexpr auto dst_datas = generate_tuple( [&](auto I) { using DataTypePointer = remove_cvref_t; using DataType = remove_pointer_t; - return StaticBuffer{}; + return DataType{}; }, Number{}); - auto in_global_buf_tuple = generate_tuple( + const auto in_global_buf_tuple = generate_tuple( [&](auto I) { return make_dynamic_buffer( - p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize()); + p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize()); }, Number{}); auto out_global_buf_tuple = generate_tuple( [&](auto I) { return make_dynamic_buffer( - 
p_out_global_tuple[I], out_grid_2d_desc_tuple[I].GetElementSpaceSize()); + p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize()); }, Number{}); - const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0); - const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1); - - const index_t loop_step_m = num_threads_m * MPerThread; - const index_t loop_step_n = num_threads_n * NPerThread; + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id)); - const index_t thread_1d_id = get_thread_global_1d_id(); - index_t tid_m = thread_1d_id / num_threads_n; - index_t tid_n = thread_1d_id % num_threads_n; - - const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread); - - auto in_global_load_tuple = generate_tuple( - [&](auto I) { - using DataTypePointer = remove_cvref_t; - using DataType = remove_cv_t>; - - return ThreadwiseTensorSliceTransfer_v2< - DataType, - DataType, - decltype(in_grid_2d_desc_tuple[I]), - decltype(thread_buffer_desc_mn), - Sequence, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 0, // SrcVectorDim - InScalarPerVectorSeq::At(I), // ScalarPerVector - 1, // SrcScalarStrideInVector - true>{in_grid_2d_desc_tuple[I], thread_global_offset}; + const index_t m0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock); + const index_t m1_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock); + const auto input_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); }, Number{}); - - auto out_global_store_tuple = generate_tuple( - [&](auto I) { - using DataTypePointer = remove_cvref_t; - using DataType = remove_pointer_t; - - return ThreadwiseTensorSliceTransfer_v1r3< - DataType, - DataType, - decltype(thread_buffer_desc_mn), - decltype(out_grid_2d_desc_tuple[I]), - PassThroughOp, - Sequence, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // OutScalarPerVectorSeq::At(I), - InMemoryDataOperationEnum::Set, - 1, - true>(out_grid_2d_desc_tuple[I], thread_global_offset, PassThroughOp{}); + const auto output_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); }, Number{}); - index_t num_iter_m = M / (loop_step_m); - do - { - index_t num_iter_n = N / (loop_step_n); - do - { - static_for<0, NumInput, 1>{}([&](auto I) { - in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I], - in_global_buf_tuple[I], - thread_buffer_desc_mn, - make_tuple(I0, I0), - in_thread_buf_tuple(I)); - - in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I], - make_multi_index(0, loop_step_n)); - }); - - static_for<0, MPerThread, 1>{}([&](auto iM) { - static_for<0, NPerThread, 1>{}([&](auto iN) { - constexpr auto offset = - thread_buffer_desc_mn.CalculateOffset(make_tuple(iM, iN)); - // get reference to in data - const auto in_data_refs = generate_tie( - // return type should be lvalue - [&](auto I) -> const auto& { - return in_thread_buf_tuple(I)(Number{}); - }, - Number{}); - - // get referenec to dst data - auto out_data_refs = generate_tie( - // return type should be lvalue - [&](auto I) -> auto& { - return out_thread_buf_tuple(I)(Number{}); - }, - Number{}); - unpack2(elementwise_op, out_data_refs, in_data_refs); - }); - }); - - static_for<0, NumOutput, 1>{}([&](auto I) { - out_global_store_tuple(I).Run(thread_buffer_desc_mn, - make_tuple(I0, I0), - 
out_thread_buf_tuple[I], - out_grid_2d_desc_tuple[I], - out_global_buf_tuple(I)); - - out_global_store_tuple(I).MoveDstSliceWindow(out_grid_2d_desc_tuple[I], - make_multi_index(0, loop_step_n)); - }); - - } while(--num_iter_n); - - static_for<0, NumInput, 1>{}([&](auto I) { - in_global_load_tuple(I).MoveSrcSliceWindow( - in_grid_2d_desc_tuple[I], - make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n)); - }); - - static_for<0, NumOutput, 1>{}([&](auto I) { - out_global_store_tuple(I).MoveDstSliceWindow( - out_grid_2d_desc_tuple[I], - make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n)); - }); - } while(--num_iter_m); + using ThisThreadBlock = ThisThreadBlock; + // If src and dst have same vector dim, then: + // M0 dim - for src and dst vector load/store + // else: + // M0 dim - for dst vector load + // M1 dim - for src vector store + using SrcDimAccessOrder = + std::conditional_t, Sequence<1, 0>>; + using DstDimAccessOrder = + std::conditional_t, Sequence<1, 0>>; + + using ThreadClusterLengths = + Sequence{}, Number{}>; + + auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2< + ThisThreadBlock, + ElementwiseOperation, + uniform_sequence_gen_t(InMemoryDataOperationEnum::Set)>, + Sequence, + ThreadClusterLengths, + ThreadClusterArrangeOrder, + decltype(src_datas), + decltype(dst_datas), + InGridDescTuple, + OutGridDescTuple, + SrcDimAccessOrder, + DstDimAccessOrder, + SrcVectorDim, + DstVectorDim, + InScalarPerVectorSeq, + OutScalarPerVectorSeq, + uniform_sequence_gen_t, + uniform_sequence_gen_t, + uniform_sequence_gen_t, + uniform_sequence_gen_t>{in_grid_desc_tuple, + input_thread_grid_offset, + out_grid_desc_tuple, + output_thread_grid_offset, + elementwise_op}; + global_to_global_transfer.Run( + in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index 3ea72b85345869a938e52f872c2c586c6e36170f..072275c089ce469a1d7e043799e290a28c0b310c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -119,7 +119,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk index_t num_k_block_tile_iteration, AccDataType epsilon, const InDataTypePointerTuple p_in_global_tuple, - XDataType* const __restrict__ p_x_lds, + XDataType* const __restrict__ p_x_lds_, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, @@ -149,7 +149,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); auto x_lds_val_buf = make_dynamic_buffer( - p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); + p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); auto in_thread_buf_tuple = generate_tuple( [&](auto) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21dac6f9e9eb59ae5af6bc81aea5e496231fc3b8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -0,0 +1,1053 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced 
Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const ScaleGridDesc scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_scale_grid, + p_c_grid, + p_shared, + a_grid_desc, + b_grid_desc, + scale_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_scale_grid; + ignore = p_c_grid; + ignore = a_grid_desc; + ignore = b_grid_desc; + ignore = scale_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx11__)) +} + +// Assume B is Col-Major +template +struct GridwiseFpAintBGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // FIX ME: To be deprecated + static constexpr auto K1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 
32 : 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor() + { + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else + constexpr auto A_KRow = I1; +#endif + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + 
make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + // Err: merge transform cause non-constexpr issue + + // return transform_tensor_descriptor( + // ABlockDesc_{}, + // make_tuple(make_merge_transform(make_tuple(Number{}, I1)), + // make_pass_through_transform(Number{}), + // make_pass_through_transform(I1), + // make_pass_through_transform(I1), + // make_pass_through_transform(Number{})), + // make_tuple(Sequence<0, 3>{}, + // Sequence<1>{}, + // Sequence<2>{}, + // Sequence<4>{}, + // Sequence<5>{}), + // make_tuple( + // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, + // Sequence<4>{})); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else + constexpr auto B_KRow = I1; +#endif + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp err: ProblemSize division"); + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + return false; + } + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using 
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and Dequantized B: be careful of DataType + // scale would not put into LDS. + using LDS_ADataType = ADataType; + using LDS_BDataType = ADataType; + using LDS_CDataType = CShuffleDataType; + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + // B would be dequantize to ADataType before enter LDS + // b_lds_offset = LDS size allocated for a in byte / LDS_BDataType + static constexpr auto b_block_space_offset = + (a_block_space_offset + a_block_space_size_aligned) * sizeof(LDS_ADataType) / + sizeof(LDS_BDataType); + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(LDS_CDataType), + a_block_space_size_aligned * sizeof(LDS_ADataType) + + b_block_space_size_aligned * sizeof(LDS_BDataType)); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const ScaleGridDesc& scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. 
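+// The loads below go through read-only global buffers for A, B and the
+// dequantization scales. When B is staged through LDS, the dequantizing
+// blockwise copy converts the integer B data to ADataType before it reaches
+// LDS (see SharedMemTrait), so the block GEMM consumes ADataType operands;
+// C is written back through the LDS shuffle at the end of Run().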
+ const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc.GetElementSpaceSize()); + const auto scale_grid_buf = make_dynamic_buffer( + p_scale_grid, scale_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) + * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto 
a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + SharedMemTrait::b_block_space_size_aligned); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1_dequant, +/* typename BlockScaleSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ BBlockTransferThreadClusterLengths_K0_N_K1, +/* typename ThreadClusterArrangeOrder, */ BBlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ BDataType, +/* typename ScaleData, */ ScaleDataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(b_grid_desc), +/* typename ScaleDesc, */ decltype(scale_grid_desc), +/* typename DstDesc, */ decltype(b_block_desc), +/* typename SrcDimAccessOrder, */ BBlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ BBlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ BBlockTransferSrcScalarPerVector, +/* index_t ScaleScalarPerVector, */ 1, +/* index_t DstScalarPerVector, */ BBlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t ScaleScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ BThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + scale_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + ck::tensor_operation::element_wise::PassThrough{}, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; +/*******************************************************************************/ + // GEMM + constexpr auto 
KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + BlockwiseGemmWMMA{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); + + // gridwise GEMM pipeline + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc, + b_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + scale_grid_desc, + scale_grid_buf, + blockwise_gemm, + c_thread_buf, + KBlockMainLoop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::c_shuffle_block_space_offset, + SharedMemTrait::c_shuffle_block_space_size); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = 
c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + 
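+            // (the barrier above also orders this write against the previous
+            // access_id iteration, whose LDS-to-global copy may still be
+            // reading c_shuffle_block_buf)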
// each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 3ced4b9ad61fe0be2783fb6f261b9131fb8d1622..35176591c1cfdd4ed0037b5baf24965660e4a0ee 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -67,7 +67,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -504,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp index 27f48a84ba72fe284a964ac295540eac1c4f6766..5f6f2768ebc80f231df9b10d3d8a15ae99115088 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
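+// Note on the validity-check additions below: grid descriptors whose element
+// space exceeds 2^31 bytes are now rejected up front (presumably to keep
+// buffer offsets within a signed 32-bit range); e.g. an fp16 A tensor larger
+// than 2^30 (~1.07e9) elements already fails the 2 GB check.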
#pragma once @@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDesc_M_N& c_grid_desc_m_n) { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB)) + { + return false; + } + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 1d1bb6ed2d446fc61f4056acea275e048f002997..562b9b8ffa3b97d35932dba691f83a81bac3fe95 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,11 +7,9 @@ #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp" #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" @@ -19,8 +17,6 @@ namespace ck { -using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm; - template + bool HasDoubleTailKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -43,13 +38,6 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -// DPP8 is currently only supported on gfx1030 -#if !defined(__gfx1030__) - if(GemmDlAlg == GemmDlAlgorithm::Dpp8) - { - return; - } -#endif constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -100,8 +88,7 @@ template + index_t CThreadTransferDstScalarPerVector> struct GridwiseGemmDl_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; @@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 c_grid_desc_m_n); } - template - __host__ __device__ static constexpr auto GetBlockwiseGemm() - { - if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8) - { - return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - ABlockDesc_BK0_BM_BK1, - BBlockDesc_BK0_BN_BK1, - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - } - else - { - return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - ABlockDesc_BK0_BM_BK1, - BBlockDesc_BK0_BN_BK1, - M1PerThreadM111, - 
N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - } - } - using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using CGridDesc_M0_M10_M11_N0_N10_N11 = @@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register const auto blockwise_gemm = - GetBlockwiseGemm(); + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); @@ -688,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3 const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1, const CGridDesc_M_N& c_grid_desc_m_n) { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB)) + { + return false; + } + const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2); const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2); const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b473d7cbf2b73070eacc38cce27b419e52adb179 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -0,0 +1,701 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
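+//
+// Gridwise GEMM built on the DPP (cross-lane data sharing) blockwise GEMM; the
+// device guard in kernel_gemm_dpp limits it to gfx103*/gfx11* targets.
+// Rough host-side flow, for illustration only (template arguments are omitted
+// and the surrounding device-op wrapper is not part of this diff):
+//
+//   using Gridwise = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp</* tuning params */>;
+//   Gridwise::Argument karg(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC);
+//   if(Gridwise::CheckValidity(karg))
+//   {
+//       const auto grid_size     = Gridwise::CalculateGridSize(karg.M, karg.N);
+//       const bool has_main_loop = Gridwise::CalculateHasMainKBlockLoop(karg.K);
+//       // branch on has_main_loop and launch
+//       // kernel_gemm_dpp<Gridwise, true / false>(karg) with that grid size
+//   }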
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( + GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA)); + const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane( + GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane( + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_m_n); +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + static constexpr auto max_lds_align = math::lcm(AK1, BK1); + + using ThisThreadBlock = ThisThreadBlock; + // return block_id to C matrix tile idx (m0, n0) mapping + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); } + __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + 
AK0{CalculateAK0(K)}, + BK0{CalculateBK0(K)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t AK0; + index_t BK0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const ABDataType* p_a_grid_, + const ABDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ABDataType* p_a_grid; + const ABDataType* p_b_grid; + CDataType* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType); + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "Wrong! AK1 must be known at the time of compilation."); + static_assert(is_known_at_compile_time>::value, + "Wrong! BK1 must be known at the time of compilation."); + + static_assert( + MPerBlock % (MPerDpp * MDppPerWave) == 0, + "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave."); + static_assert( + NPerBlock % (NPerDpp * NDppPerWave) == 0, + "Invalid tuning parameters! 
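GetSharedMemoryNumberOfByte above sizes the LDS as the A-tile and B-tile element counts, each rounded up to the lcm(AK1, BK1) alignment, times the element size. The following standalone sketch reproduces that computation with generic sizes; the parameter names are illustrative rather than taken from the header.

#include <cstddef>
#include <numeric> // std::lcm

// Smallest multiple of `alignment` that is >= `size`.
std::size_t integer_least_multiple(std::size_t size, std::size_t alignment)
{
    return ((size + alignment - 1) / alignment) * alignment;
}

// LDS bytes needed for the A and B block tiles, mirroring the header's formula.
std::size_t lds_bytes(std::size_t a_block_elems,
                      std::size_t b_block_elems,
                      std::size_t ak1,
                      std::size_t bk1,
                      std::size_t elem_size)
{
    const std::size_t align = std::lcm(ak1, bk1);
    return (integer_least_multiple(a_block_elems, align) +
            integer_least_multiple(b_block_elems, align)) * elem_size;
}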
NPerBlock must be divisible by NPerDpp * NDppPerWave."); + + static_assert( + KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1."); + + static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! AK1Value must be divisible by " + "ABlockTransferDstScalarPerVector_K1"); + + static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! BK1Value must be divisible by " + "BBlockTransferDstScalarPerVector_K1"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if(problem.K % KPerBlock != 0) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = problem.K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const auto num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n) + { + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + + using BlockwiseGemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + } + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + __device__ static auto + MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + 
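The runtime part of CheckValidity above only demands exact tiling of a dimension when the chosen GemmSpecialization does not pad it, and always requires K to be a multiple of KPerBlock. The sketch below condenses that decision; the two-flag PaddingSpec is a simplification of the full GemmSpecialization enum, introduced here only for brevity.

// Simplified stand-in for GemmSpecialization: which dimensions get padded.
struct PaddingSpec { bool pad_m; bool pad_n; };

bool is_supported(int M, int N, int K,
                  int MPerBlock, int NPerBlock, int KPerBlock,
                  PaddingSpec spec)
{
    if(!spec.pad_m && (M % MPerBlock != 0)) return false; // no M padding -> must tile exactly
    if(!spec.pad_n && (N % NPerBlock != 0)) return false; // no N padding -> must tile exactly
    if(K % KPerBlock != 0) return false;                  // K is never padded here
    return true;
}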
make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __device__ static auto + MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(BK0, BK1Value))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + } + + __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + 
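The Run body above first converts the 1-D block id into a C-tile coordinate through the block-to-C-tile map and then anchors the work-group's tile at (m0 * MPerBlock, n0 * NPerBlock). The sketch below shows only the interface with a plain row-major mapping; the real BlockToCTileMap_M00_N0_M01Adapt swizzles the traversal order for locality, which this illustration deliberately omits.

#include <utility>

// Minimal row-major block-id -> (m0, n0) tile mapping, for illustration only.
std::pair<int, int> block_to_c_tile(int block_id, int N, int NPerBlock)
{
    const int n_tiles = (N + NPerBlock - 1) / NPerBlock;
    return {block_id / n_tiles, block_id % n_tiles};
}

// The work-group then reads/writes the C tile whose origin is
//   m_block_data_idx_on_grid = m0 * MPerBlock
//   n_block_data_idx_on_grid = n0 * NPerBlock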
ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[AK0PerBlock, MPerBlock] is in LDS + // b_mtx[BK0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + auto blockwise_gemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0); + // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock) + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + 
constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..60c02d64e11cc8be94de6bd6aa8041f2d30b55a4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,1038 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { + +// GEMM: +// input : A0[M, K], A1[M, K] +// input : B0[N, K], B1[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
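The comment block opening the new multiple-ABD header states the contract: C = a_op(A0, A1, ...) * b_op(B0, B1, ...) and E = cde_op(C, D0, D1, ...), with the Ds and E sharing a layout. A naive host-side reference of that contract is sketched below, assuming two A tensors, two B tensors and a single D purely for illustration (the kernel itself handles arbitrary counts), with A stored row-major M x K, B stored N x K, and D/E stored M x N.

#include <vector>
#include <functional>

void reference_multiple_abd_gemm(
    const std::vector<float>& a0, const std::vector<float>& a1,
    const std::vector<float>& b0, const std::vector<float>& b1,
    const std::vector<float>& d0, std::vector<float>& e,
    int M, int N, int K,
    const std::function<float(float, float)>& a_op,
    const std::function<float(float, float)>& b_op,
    const std::function<float(float, float)>& cde_op)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a_op(a0[m * K + k], a1[m * K + k]) * // combine the A inputs
                       b_op(b0[n * K + k], b1[n * K + k]);  // combine the B inputs
            e[m * N + n] = cde_op(acc, d0[m * N + n]);      // fuse C with the D inputs
        }
}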
and E have the same layout +template +struct GridwiseGemmMultipleABD_xdl_cshuffle +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + +#if CK_WORKAROUND_DENORM_FIX + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; +#else + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + static constexpr auto MakeAsGridPointer() + { + return generate_tuple( + [&](auto i) { + using ADataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeBsGridPointer() + { + return generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + 
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + b_block_space_size_aligned * sizeof(BComputeDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k) + { + return generate_tuple( + [&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); }, + Number{}); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k) + { + return generate_tuple( + [&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); }, + Number{}); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + // 
block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AsGridDesc_M_K& as_grid_desc_m_k, + const BsGridDesc_N_K& bs_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + + const auto M = as_grid_desc_m_k[I0].GetLength(I0); + const auto N = bs_grid_desc_n_k[I0].GetLength(I0); + const auto AK = as_grid_desc_m_k[I0].GetLength(I1); + const auto BK = bs_grid_desc_n_k[I0].GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) + { + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + bool valid = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + valid = + valid && (as_grid_desc_m_k[i].GetElementSpaceSize() * sizeof(ADataType) <= TwoGB); + valid = valid && (M == as_grid_desc_m_k[i].GetLength(I0) && + AK == as_grid_desc_m_k[i].GetLength(I1)); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + valid = + valid && (bs_grid_desc_n_k[i].GetElementSpaceSize() * sizeof(BDataType) <= TwoGB); + valid = valid && (N == bs_grid_desc_n_k[i].GetLength(I0) && + BK == bs_grid_desc_n_k[i].GetLength(I1)); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = AK / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + + if(!(e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using AsGridPointer = decltype(MakeAsGridPointer()); + using BsGridPointer = decltype(MakeBsGridPointer()); + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto + 
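CheckValidity above folds a single `valid` flag over every A, B and D descriptor with static_for: each tensor must agree on the shared problem dimensions and stay within the 2^31-byte limit. The runtime loop below shows the same accumulation pattern for the A tensors; DescInfo is a hypothetical plain struct standing in for the compile-time descriptor tuple.

#include <cstdint>
#include <vector>

struct DescInfo { std::int64_t rows, cols, space_bytes; };

// Plain-loop equivalent of the static_for fold over the A descriptors.
bool as_are_valid(const std::vector<DescInfo>& as, std::int64_t M, std::int64_t AK)
{
    const std::int64_t two_gb = std::int64_t{1} << 31;
    bool valid = true;
    for(const auto& a : as)
    {
        valid = valid && (a.space_bytes <= two_gb);     // each tensor under 2 GB
        valid = valid && (a.rows == M && a.cols == AK); // shapes must be consistent
    }
    return valid;
}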
MakeAsGridDescriptor_M_K(const std::array& MRaws, + const std::array& KRaws, + const std::array& AsStride) + { + return generate_tuple( + [&](auto i) { + using ALayout = remove_cvref_t>; + + return MakeAGridDescriptor_M_K(MRaws[i], KRaws[i], AsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(const index_t NRaw, const index_t KRaw, const index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto + MakeBsGridDescriptor_N_K(const std::array& NRaws, + const std::array& KRaws, + const std::array& BsStride) + { + return generate_tuple( + [&](auto i) { + using BLayout = remove_cvref_t>; + + return MakeBGridDescriptor_N_K(NRaws[i], KRaws[i], BsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = 
make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + const auto idx_as_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + decltype(as_grid_desc_ak0_m_ak1), + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; + + const auto idx_bs_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + decltype(bs_grid_desc_bk0_n_bk1), + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + AComputeDataType, + BComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of 
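KPack above is the K-chunk each wave hands to the selected XDL/MFMA instruction: at least the lcm of the two LDS vector widths and at least the instruction's k-per-block. A one-line sketch of that selection follows; the values 8/8/4 in the check are placeholders, not parameters of this header.

#include <algorithm>
#include <numeric>

constexpr int compute_kpack(int ak1, int bk1, int mfma_k_per_blk)
{
    return std::max(std::lcm(ak1, bk1), mfma_k_per_blk);
}

static_assert(compute_kpack(8, 8, 4) == 8, "example values for illustration only");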
alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / + KPerBlock); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + gridwise_gemm_pipeline.template Run(as_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + as_grid_buf, + a_block_buf, + a_block_slice_copy_step, + bs_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + bs_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // 
blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // 
ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const std::array StrideAs, + const std::array StrideBs, + const std::array StrideDs, + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + using AsGridDesc_M_K = + remove_cvref_t({}, {}, {}))>; + using BsGridDesc_N_K = + remove_cvref_t({}, {}, {}))>; + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + AsGridDesc_M_K as_grid_desc_m_k; + BsGridDesc_N_K bs_grid_desc_n_k; + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumATensor, 1>{}([&](auto j) { + using ALayout = remove_cvref_t>; + + as_grid_desc_m_k(j) = MakeAGridDescriptor_M_K(M, K, StrideAs[j]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto j) { + using BLayout = remove_cvref_t>; + + bs_grid_desc_n_k(j) = MakeBGridDescriptor_N_K(N, K, StrideBs[j]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k); 
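The epilogue above walks the space-filling curve one access at a time, syncing LDS before each VGPR-to-LDS write and again before the block-wide LDS-to-global store, then advancing the D/E slice windows. The skeleton below keeps only that ordering; every function in it is a placeholder stub (the real code uses block_sync_lds, the threadwise and blockwise copy objects, and MoveSrc/DstSliceWindow).

#include <cstdio>

void block_sync_lds_stub()           { /* __syncthreads()-style barrier on device */ }
void copy_vgpr_to_lds_stub(int id)   { std::printf("vgpr->lds   access %d\n", id); }
void copy_lds_to_global_stub(int id) { std::printf("lds->global access %d\n", id); }
void move_slice_windows_stub()       { /* advance the D/E windows along the curve */ }

void run_shuffle_epilogue(int num_access)
{
    for(int access_id = 0; access_id < num_access; ++access_id)
    {
        block_sync_lds_stub();              // LDS is free to be overwritten
        copy_vgpr_to_lds_stub(access_id);   // each thread writes its C fragment
        block_sync_lds_stub();              // LDS fully written, safe to read
        copy_lds_to_global_stub(access_id); // block-wide store of the C/D/E tile
        if(access_id < num_access - 1)
            move_slice_windows_stub();      // step to the next curve position
    }
}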
+ + const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index d710fc18944ddb538ffb2d0385db862958e0542f..5c9f40b51a5ba28810352cf9dcf6d6e2dca1a404 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -470,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 98ade85a3089ccc5638fb45d4270488ebab59f3a..de6c9c1601ae82bf01a309ac67d8a8212121a656 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -36,7 +36,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle( + kernel_grouped_conv_multiple_d_wmma_cshuffle( const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, DsPointer p_ds_grid, @@ -45,8 +45,8 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const AGridDesc_AK0_M_AK1 a_grid_desc, + const BGridDesc_BK0_N_BK1 b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock @@ -54,23 +54,22 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t b_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t e_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; DsPointer p_ds_grid_grp; @@ -85,8 +84,8 @@ __global__ void p_ds_grid_grp, p_e_grid + e_batch_offset, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock_, a_element_op, @@ -99,8 +98,8 @@ __global__ void ignore = p_ds_grid; ignore = p_e_grid; ignore = batch_count; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; ignore = a_element_op; @@ -116,8 +115,8 @@ template (compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t b_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t e_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -172,20 +169,16 @@ __global__ void DsPointer p_ds_grid_grp; - // printf("before allocate pointer d"); - static_for<0, NumDTensor, 1>{}( 
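The reworked WMMA struct above derives its wave tiling from the block tiling (MWaves = MPerBlock / (MRepeat * MPerWmma), likewise for N) and widens the WMMA K chunk when K1 is 16. The constexpr sketch below spells out that arithmetic; all numeric values are hypothetical tuning parameters chosen only so the divisions are visible.

constexpr int MPerBlock = 128, NPerBlock = 128; // hypothetical block tile
constexpr int MRepeat = 4, NRepeat = 2;         // hypothetical repeats per wave
constexpr int MPerWmma = 16, NPerWmma = 16;     // WMMA fragment size
constexpr int K1 = 8;                           // hypothetical K1 vector width

constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 128 / 64 = 2
constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 128 / 32 = 4
constexpr int WmmaK  = (K1 == 16) ? 32 : 16;             // wider K chunk only for K1 == 16

static_assert(MWaves > 0 && NWaves > 0, "block tile must decompose into whole wave tiles");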
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - // printf("before entry"); - GridwiseOp::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, p_ds_grid_grp, p_e_grid + e_batch_offset, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -201,8 +194,8 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = cde_element_op; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; ignore = block_2_etile_map; @@ -215,8 +208,8 @@ template (p_a_grid, p_b_grid, p_ds_grid, p_e_grid, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -266,15 +258,15 @@ __global__ void ignore = p_b_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; ignore = a_element_op; ignore = b_element_op; ignore = cde_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__ )) +#endif // end of if (defined(__gfx11__ )) } template < // DataType Family @@ -285,8 +277,8 @@ template < // DataType Family typename DsDataType, typename EDataType, // InMemory Data Descriptor - typename AGridDesc_K0_M_K1, - typename BGridDesc_K0_N_K1, + typename AGridDesc, + typename BGridDesc, typename DsGridDesc_M_N, typename EGridDesc_M_N, // ElementwiseOp Family @@ -297,7 +289,7 @@ template < // DataType Family // Tiling Family index_t MPerBlock, index_t NPerBlock, - index_t K0PerBlock, + index_t KPerBlock, index_t MPerWmma, index_t NPerWmma, index_t K1Value, @@ -312,6 +304,7 @@ template < // DataType Family index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_K1, bool AThreadTransferSrcResetCoordinateAfterRun, + bool AEnableLds, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterArrangeOrder, @@ -320,6 +313,7 @@ template < // DataType Family index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_K1, bool BThreadTransferSrcResetCoordinateAfterRun, + bool BEnableLds, bool BBlockLdsExtraN, index_t CShuffleMRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle, @@ -328,7 +322,7 @@ template < // DataType Family index_t NumGemmKPrefetchStage = 1, LoopScheduler LoopSched = make_default_loop_scheduler(), PipelineVersion PipelineVer = PipelineVersion::v1> -struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle +struct GridwiseGemmMultipleD_Wmma { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -344,68 +338,255 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle // K1 should be Number<...> static constexpr auto K1 = Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 
32 : 16; + using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< - decltype(GridwiseGemmPipeline_Selector())>; + using GridwiseGemmPipe = + remove_cvref_t())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() { - constexpr auto max_lds_align = K1; + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { - if constexpr(ABlockLdsExtraM) + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else { + constexpr auto A_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor() + { + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto B_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return a_block_desc_k0perblock_mperblock_k1; + return b_block_desc; } - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() { - constexpr auto max_lds_align = K1; + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { - if constexpr(BBlockLdsExtraN) + return make_multi_index(K0PerBlock, 0, 0); + } + else { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static 
constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); } }(); - return b_block_desc_k0perblock_nperblock_k1; + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else + constexpr auto A_KRow = I1; +#endif + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else + constexpr auto B_KRow = I1; +#endif + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; } __host__ __device__ static constexpr auto // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = make_naive_tensor_descriptor_packed( 
make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } @@ -422,43 +603,100 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle Number{}); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + // CheckValidity for kernels without multi D + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0perblock_mperblock_k1 = - GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + return false; + } - constexpr auto b_block_desc_k0perblock_nperblock_k1 = - GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + return false; + } - constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; - constexpr auto max_lds_align = K1; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); + if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); - constexpr auto c_block_space_size_aligned = math::integer_least_multiple( - cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(), - max_lds_align); + if(!(a_grid_desc.GetElementSpaceSize() * 
sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), - c_block_space_size_aligned * sizeof(CShuffleDataType)); + return true; } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const DsGridDesc_M_N& ds_grid_desc_m_n, - const EGridDesc_M_N& e_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -467,9 +705,39 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle (NPerBlock % (NRepeat * NPerWmma)) == 0, "Invalid tuning param!"); - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; bool valid = true; @@ -480,19 +748,25 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle if(!valid) { + printf("GridwiseOp: D descriptor dimension check failure\n"); return false; } if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + K == GetBProblemsizeNK()[I1])) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); return false; + } - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); return false; + } // check gridwise gemm pipeline - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = K / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { @@ -507,8 +781,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) constexpr long_index_t TwoGB = (long_index_t{1} << 31); - 
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && - b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { return false; @@ -519,7 +793,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = K / KPerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -565,6 +839,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle e_grid_desc_m_n); } + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + }; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; @@ -581,8 +886,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid, void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& @@ -592,14 +897,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const CDEElementwiseOperation& cde_element_op, const Block2CTileMap& block_2_ctile_map) { - // printf("safe entry"); // clang-format off /*******************************************************************************/ // Memory buffer zone. 
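The SharedMemTrait block above sizes a single LDS pool: the A and B tiles sit side by side during the main loop, and the C-shuffle staging tile reuses the same space afterwards, so the total is the max of the two footprints. A minimal host-side sketch of that budgeting follows; it is illustrative only, and all tile numbers and element sizes are hypothetical, not values taken from this kernel.

#include <algorithm>
#include <cstdio>

constexpr int round_up(int v, int a) { return (v + a - 1) / a * a; }

int main()
{
    // Hypothetical tile configuration.
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32, K1 = 8;
    constexpr int K0PerBlock = KPerBlock / K1;

    // A/B block tiles staged in LDS (element counts), aligned to K1.
    constexpr int a_block_elems = round_up(K0PerBlock * MPerBlock * K1, K1);
    constexpr int b_block_elems = round_up(K0PerBlock * NPerBlock * K1, K1);

    // C-shuffle staging tile, e.g. one shuffle slice of the block tile.
    constexpr int c_shuffle_elems = 32 * 128; // hypothetical slice size

    // A starts at offset 0, B right after A; the C-shuffle buffer starts at
    // offset 0 again because it only lives after the main loop has finished.
    constexpr int bytes_ab = a_block_elems * 2 + b_block_elems * 2; // 2-byte elements (e.g. fp16)
    constexpr int bytes_c  = c_shuffle_elems * 4;                   // 4-byte shuffle elements
    constexpr int lds_bytes = std::max(bytes_ab, bytes_c);

    std::printf("A offset 0, B offset %d elems, total LDS %d bytes\n",
                a_block_elems, lds_bytes);
}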
const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + p_a_grid, a_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + p_b_grid, b_grid_desc.GetElementSpaceSize()); const auto ds_grid_buf = generate_tuple( [&](auto i) { return make_dynamic_buffer( @@ -625,13 +929,30 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle /*******************************************************************************/ // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - constexpr auto max_lds_align = K1; - constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc.GetElementSpaceSize()); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, @@ -651,92 +972,189 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle /* index_t SrcScalarStrideInVector, */ 1, /* index_t DstScalarStrideInVector, */ 1, /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, -/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( - a_grid_desc_k0_m_k1, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, - a_block_desc_k0perblock_mperblock_k1, + a_block_desc, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0perblock_nperblock_k1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0perblock_nperblock_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor 
should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_size_aligned, + b_block_desc.GetElementSpaceSize()); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc), + decltype(b_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; /*******************************************************************************/ // GEMM - constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle{}; + BlockwiseGemmWMMA{}; // Prepare Register for C matrix auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); -/*******************************************************************************/ - constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - // LDS allocation for A and B: be careful of alignment - auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); - 
+/*******************************************************************************/ // Shift Per SUB_K - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); // gridwise GEMM pipeline - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, - a_block_desc_k0perblock_mperblock_k1, + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0perblock_nperblock_k1, + b_grid_desc, + b_block_desc, b_blockwise_copy, b_grid_buf, b_block_buf, b_block_slice_copy_step, blockwise_gemm, c_thread_buf, - K0BlockMainLoop); + KBlockMainLoop); /*******************************************************************************/ // write out to C, implement shuffle { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 3b8a5ec8f50a400af8b317e534a32c8113a1018f..e6085fad8c8b13c541beb373114a58c22f6c1a60 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,6 +15,9 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + namespace ck { // GEMM: @@ -26,7 +29,9 @@ namespace ck { // E = cde_op(C, D0, D1, ...) // Assume: // D0, D1, ... 
and E have the same layout -template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleD_xdl_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -92,15 +100,14 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - // denorm test fix, required to work around fp16 mfma issue - // we convert fp16->fp32->bf16 and execute bf16 mfma instruction - // when mfma if fixed, remove this section and update - // ABDataTypeAdjusted -> ABDataType throughout this file #if CK_WORKAROUND_DENORM_FIX - using ABDataTypeAdjusted = - conditional_t, ck::bhalf_t, ABDataType>; + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; #else - using ABDataTypeAdjusted = ABDataType; + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -169,8 +176,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ABDataType), + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + b_block_space_size_aligned * sizeof(BComputeDataType), c_block_size * sizeof(CShuffleDataType)); } @@ -250,7 +257,70 @@ struct GridwiseGemmMultipleD_xdl_cshuffle e_grid_desc_m_n); } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static bool + CheckTensorTransfersValidity(index_t MRaw, index_t NRaw, index_t KRaw) + { + // Check if the vector dim is K1 or M|N + const auto A_vector_dim_size = ABlockTransferSrcVectorDim == 2 ? KRaw : MRaw; + const auto B_vector_dim_size = BBlockTransferSrcVectorDim == 2 ? 
KRaw : NRaw; + const auto E_vector_dim_size = NRaw; + + // check vector load for A tensor + if constexpr(is_same_v) + { + if(!(A_vector_dim_size == KRaw && + A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0)) + return false; + } + else if constexpr(is_same_v) + { + if(!(A_vector_dim_size == MRaw && + A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0)) + return false; + } + else + { + return false; + } + + if constexpr(is_same_v) + { + if(!(B_vector_dim_size == NRaw && + B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0)) + return false; + } + else if constexpr(is_same_v) + { + if(!(B_vector_dim_size == KRaw && + B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0)) + return false; + } + else + { + return false; + } + + if constexpr(is_same_v) + { + if(!(E_vector_dim_size == NRaw && + E_vector_dim_size % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + return false; + } + else if constexpr(is_same_v) + { + if(!(E_vector_dim_size == NRaw && + CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1)) + return false; + } + else + { + return false; + } + + return true; + } + template {}([&](auto i) { @@ -297,24 +368,23 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // check gridwise gemm pipeline const auto num_k_loop = AK / KPerBlock; - if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } // check block-to-E-tile - if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) - { - return false; - } + // if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + //{ + // return false; + //} // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // check tensor size: cannot be larger than 2GB each constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && - b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { return false; @@ -332,14 +402,102 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using DsGridPointer = decltype(MakeDsGridPointer()); + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto + 
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + template - __device__ static void Run(const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid, void* __restrict__ p_shared, @@ -408,8 +566,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + ADataType, + AComputeDataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -439,8 +597,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Sequence, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + BDataType, + BComputeDataType, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -468,13 +626,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ABDataTypeAdjusted, + AComputeDataType, + BComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -492,11 +652,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); @@ -761,6 +920,142 @@ struct GridwiseGemmMultipleD_xdl_cshuffle }); } } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const 
CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + // tensor descriptors for problem definiton + const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K(M, K, StrideA); + const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K(K, N, StrideB); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + + const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_MK& a_grid_desc_m_k, + const BGridDesc_NK& b_grid_desc_n_k, + const DsGridDesc_MN& ds_grid_desc_m_n, + const EGridDesc_MN& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + 
b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd36b9e51ae6db6c74f491b253ae7099a259d747 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -0,0 +1,1000 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_lds.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], 
... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + +#if CK_WORKAROUND_DENORM_FIX + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; +#else + using AComputeDataType = AComputeDataType_; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, destination of blockwise copy. + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, destination of blockwise copy. + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment. + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle. 
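The shared-memory sizing in this function scales the A and B footprints by NumGemmKPrefetchStage, since every prefetch stage needs its own staging copy of the tiles, while the C-shuffle buffer again reuses the space after the main loop. A standalone sketch of that arithmetic, with hypothetical tile and element sizes:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
    constexpr std::size_t num_prefetch_stages = 2;        // e.g. double buffering
    constexpr std::size_t a_block_bytes = 128 * 32 * 2;   // hypothetical 2-byte A tile
    constexpr std::size_t b_block_bytes = 128 * 32 * 2;   // hypothetical 2-byte B tile
    constexpr std::size_t c_shuffle_bytes = 32 * 128 * 4; // hypothetical 4-byte staging tile

    // Every prefetch stage owns its own copy of the A and B tiles, so the
    // main-loop footprint scales with the stage count.
    constexpr std::size_t lds_bytes =
        std::max(num_prefetch_stages * (a_block_bytes + b_block_bytes), c_shuffle_bytes);

    std::printf("LDS bytes: %zu\n", lds_bytes);
}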
+ constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max( + NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) + + NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // A desc for source in blockwise copy. + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy. 
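The descriptor builders above first pad the raw problem sizes up to tile multiples (the MatrixPadder step) and then unmerge K into an (AK0, AK1) pair so that AK1 contiguous elements can be moved as one vector. A small numeric sketch of both steps; the problem and tile sizes are hypothetical.

#include <cstdio>

constexpr int pad_to(int v, int tile) { return (v + tile - 1) / tile * tile; }

int main()
{
    constexpr int MPerBlock = 256, NPerBlock = 128, KPerBlock = 32, AK1 = 8;

    int M_raw = 1000, N_raw = 500, K_raw = 100; // hypothetical raw GEMM sizes
    int M = pad_to(M_raw, MPerBlock);           // 1024
    int N = pad_to(N_raw, NPerBlock);           // 512
    int K = pad_to(K_raw, KPerBlock);           // 128

    // The 2-D [M, K] view is unmerged into a 3-D [AK0, M, AK1] view so the
    // innermost AK1 elements form the vector-load unit.
    int AK0 = K / AK1;

    std::printf("padded M=%d N=%d K=%d, AK0=%d AK1=%d\n", M, N, K, AK0, AK1);
}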
+ __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // E desc for destination in blockwise copy. + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy. + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using Block2ETileMap = remove_cvref_t; + + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + + static_assert( + std::is_same_v && + std::is_same_v, + "Direct load transfers do not support elementwise operations other than passthrough."); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto AK = a_grid_desc_m_k.GetLength(I1); + const auto BK = b_grid_desc_n_k.GetLength(I1); + + // Check the consistency of descriptors. + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) + { + return false; + } + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // Check the tile size. + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) + { + return false; + } + + // Check gridwise gemm pipeline. 
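CheckValidity gates the kernel on divisibility of the (padded) problem by the tile sizes, on descriptor consistency, and on the per-tensor 2 GB element-space limit. The sketch below is a simplified host-side filter in the same spirit; the function name and parameter list are illustrative, not the library's API.

#include <cstdint>
#include <cstdio>

bool is_supported(std::int64_t M, std::int64_t N, std::int64_t K,
                  int MPerBlock, int NPerBlock, int KPerBlock,
                  std::int64_t a_bytes, std::int64_t b_bytes, std::int64_t e_bytes)
{
    constexpr std::int64_t two_gb = std::int64_t{1} << 31;

    if(M % MPerBlock != 0 || N % NPerBlock != 0 || K % KPerBlock != 0)
        return false; // tile sizes must divide the (padded) problem
    if(a_bytes > two_gb || b_bytes > two_gb || e_bytes > two_gb)
        return false; // each tensor must fit within a 32-bit offset range
    return true;
}

int main()
{
    std::printf("%d\n",
                is_supported(1024, 512, 128, 256, 128, 32,
                             1024ll * 128 * 2, 512ll * 128 * 2, 1024ll * 512 * 2));
}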
+ const auto num_k_loop = AK / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // Check block-to-E-tile. + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // Check tensor size: cannot exceed 2GB. + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + // Elementwise operations are not supported for A and B, arguments left only for the API + // consistency. + (void)a_element_op; + (void)b_element_op; + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // Divide block work by [M, N]. + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // This forces m/n_block_data_idx_on_grid into SGPR. + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, destination of blockwise copy. + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, destination of blockwise copy. 
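Earlier in Run, the 1-D workgroup id is mapped to an (m_block, n_block) tile coordinate, from which m_block_data_idx_on_grid and n_block_data_idx_on_grid are derived. The sketch below uses a plain row-major mapping purely for illustration; the real BlockToCTileMap variants reorder blocks for better cache reuse.

#include <cstdio>

int main()
{
    constexpr int MPerBlock = 256, NPerBlock = 128;
    constexpr int M = 1024, N = 512; // hypothetical padded sizes
    constexpr int NBlock = N / NPerBlock;

    for(int block_id = 0; block_id < (M / MPerBlock) * NBlock; ++block_id)
    {
        int m_block = block_id / NBlock;
        int n_block = block_id % NBlock;
        int m_origin = m_block * MPerBlock; // cf. m_block_data_idx_on_grid
        int n_origin = n_block * NPerBlock; // cf. n_block_data_idx_on_grid
        std::printf("block %2d -> tile (%d, %d), origin (%4d, %4d)\n",
                    block_id, m_block, n_block, m_origin, n_origin);
    }
}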
+ constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ADataType, + AComputeDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcVectorDim, + 2, + ABlockTransferScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BDataType, + BComputeDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcVectorDim, + 2, + BBlockTransferScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + AComputeDataType, + BComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment. + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared, + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); + const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage; + auto b_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared, + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buffers, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buffers, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // Shuffle C and write out. 
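num_k_block_main_loop above works out to K divided by KPerBlock: the pipeline walks K in KPerBlock steps, staging one A tile and one B tile per iteration and accumulating into the same C tile. A plain host-side analogue of that loop structure, with hypothetical sizes:

#include <cstdio>
#include <vector>

int main()
{
    constexpr int M = 64, N = 64, K = 128, KPerBlock = 32;
    const int num_k_block_main_loop = K / KPerBlock;

    std::vector<float> A(M * K, 1.0f), B(N * K, 1.0f), C(M * N, 0.0f);

    for(int k_block = 0; k_block < num_k_block_main_loop; ++k_block) // "main K loop"
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int kk = 0; kk < KPerBlock; ++kk) // work on one staged K slice
                {
                    int k = k_block * KPerBlock + kk;
                    C[m * N + n] += A[m * K + k] * B[n * K + k]; // B stored as [N, K]
                }

    std::printf("C[0] = %f (expected %d)\n", C[0], K);
}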
+ { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // Calculate the origin of thread output tensor on global memory. + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // Shuffle: threadwise copy C from VGPR to LDS. 
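The two tensor adaptors above decompose a thread's linear position within the block tile into (M0, M1, M2, M3, M4) and (N0, N1, N2) coordinates, which is ordinary mixed-radix index decomposition. A small sketch of that decomposition; the per-dimension lengths and the sample index are hypothetical.

#include <array>
#include <cstdio>

int main()
{
    // Hypothetical per-dimension lengths, innermost dimension last (M4).
    constexpr std::array<int, 5> lengths{2, 4, 4, 2, 4};

    int linear = 57; // e.g. a thread's m_thread_data_on_block
    std::array<int, 5> idx{};
    for(int d = 4; d >= 0; --d) // peel off digits from the innermost dimension
    {
        idx[d] = linear % lengths[d];
        linear /= lengths[d];
    }

    std::printf("(M0..M4) = (%d, %d, %d, %d, %d)\n",
                idx[0], idx[1], idx[2], idx[3], idx[4]);
}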
+ auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // A tuple of reference to C/Ds tensor descriptors. + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // A tuple of reference to C/Ds grid buffers. + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // A tuple of starting index of C/Ds blockwise copy. + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // Blockwise copy C/D/E between LDS and global. + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // Space filling curve for threadwise C in VGPR before shuffle. + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // Space filling curve for shuffled blockwise C/D/E. + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // Make sure it's safe to write to LDS. + block_sync_lds(); + + // Each thread write its data from VGPR to LDS. + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // Make sure it's safe to read from LDS. 
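+            // (Every thread must have finished its VGPR-to-LDS write before the blockwise
+            // LDS-to-global copy below reads the shuffle buffer.)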
+ block_sync_lds(); + + // Each block copy its data from LDS to global. + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // Move on Ds. + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // Move on E. + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + struct Argument : public tensor_operation::device::BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + a_grid_desc_ak0_m_ak1_{MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw} + { + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + }); + + if(CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // Pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // Tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // 
block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise ops + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // For checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ae93af192e6768092b707b9bcc75138453b7c873 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -0,0 +1,1172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ALDSType) + + b_block_space_size_aligned * sizeof(BLDSType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + 
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t 
K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + 
ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ALDSType, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BLDSType, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ALDSType, + BLDSType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched, + AComputeType, + BComputeType>(); + + if constexpr(Zeroing) + { + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + 
true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __builtin_amdgcn_s_barrier(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } + } + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if constexpr(Zeroing) + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + __builtin_amdgcn_s_barrier(); + } + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = 
concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + 
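+                    // the E window advances by the same space-filling-curve step as the Ds windows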
cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if constexpr(Zeroing) + { + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + } + + template + __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run, true>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t*, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& 
cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + nullptr, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 7c11316cefa8058b5a127c316ddad5d6d7a5fad4..13aecd20e6fafa3661481a0918815bd545ce2ada 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -1,12 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp" #ifndef __HIPCC_RTC__ #include +#include #endif namespace ck { @@ -15,18 +17,23 @@ enum struct PipelineVersion { v1, v2, + // v3 is only used in the Stream-K implementation. 
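+    // v4 uses direct global-to-LDS loads (gridwise_gemm_pipeline_v4_direct_load.hpp).
+    // weight_only fuses B-matrix dequantization into the blockwise copy.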
+ v4, + weight_only, }; template + LoopScheduler LoopSched = LoopScheduler::Default, + bool AEnableLds = true, + bool BEnableLds = true> constexpr auto GridwiseGemmPipeline_Selector() { if constexpr(PipelineVer == PipelineVersion::v1) { if constexpr(LoopSched == LoopScheduler::Default) { - return GridwiseGemmPipeline_v1{}; + return GridwiseGemmPipeline_v1{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { @@ -37,6 +44,14 @@ constexpr auto GridwiseGemmPipeline_Selector() { return GridwiseGemmPipeline_v2{}; } + else if constexpr(PipelineVer == PipelineVersion::v4) + { + return GridwiseGemmPipeline_v4{}; + } + else if constexpr(PipelineVer == PipelineVersion::weight_only) + { + return GridwiseGemmPipeline_v1_WeightOnly{}; + } else { #ifndef __HIPCC_RTC__ @@ -46,3 +61,16 @@ constexpr auto GridwiseGemmPipeline_Selector() } } // namespace ck + +inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) +{ + switch(p) + { + case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break; + case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break; + case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break; + case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break; + default: os << ""; + } + return os; +} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index d1209636de93a13cbe7f851575b4abaedd99af72..0cdb7ce2ca01bcc8a7afe594c044791891e4288f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -4,16 +4,17 @@ #pragma once #include "ck/utility/common_header.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { -template +template struct GridwiseGemmPipeline_v1; // 1-stage prefetch template <> -struct GridwiseGemmPipeline_v1<1> +struct GridwiseGemmPipeline_v1<1, true, true> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -107,7 +108,7 @@ struct GridwiseGemmPipeline_v1<1> // 2-stage prefetch template <> -struct GridwiseGemmPipeline_v1<2> +struct GridwiseGemmPipeline_v1<2, true, true> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -253,6 +254,406 @@ struct GridwiseGemmPipeline_v1<2> } }; +template <> +struct GridwiseGemmPipeline_v1<1, false, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto a_block_buf_switch = 
a_block_buf; + + // preload data into LDS + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_block_buf = a_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template <> +struct GridwiseGemmPipeline_v1<1, true, false> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto b_block_buf_switch = b_block_buf; + + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch); + + block_sync_lds(); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + b_block_buf = b_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template <> +struct GridwiseGemmPipeline_v1<1, false, false> +{ + static constexpr auto I0 = Number<0>{}; + static 
constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto b_block_buf_switch = b_block_buf; + auto a_block_buf_switch = a_block_buf; + + // preload data into LDS + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch); + + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_block_buf = a_block_buf_switch; + b_block_buf = b_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template +struct GridwiseGemmPipeline_v1_WeightOnly; + +template <> +struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const ScaleGridDesc& scale_grid_desc, + const ScaleGridBuffer& scale_grid_buf, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // Global Prefetch Stage 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + // Scale read once + b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf); + + 
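+        // The dequantization scales stay resident for the whole K loop; only the A and B
+        // tiles are re-read in the main body below.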
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // Dequantization fused in blockwise_copy + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + template struct GridwiseGemmPipelineInterwave_v1; @@ -348,7 +749,7 @@ struct GridwiseGemmPipelineInterwave_v1<1> // Note: 2 stage prefetch not optimized for inter-wave loop scheduler template <> -struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> +struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true> { }; @@ -358,7 +759,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector() { if constexpr(LoopSched == LoopScheduler::Default) { - return GridwiseGemmPipeline_v1{}; + return GridwiseGemmPipeline_v1{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp index d3d7d5af8577d1c2ed65fa817f708379c6089a25..25e1cebdbe06918cfea98132209116a2d74de822 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -9,13 +9,13 @@ namespace ck { struct GridwiseGemmPipeline_v2 { - __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + __host__ __device__ static constexpr bool IsSupported(const index_t num_loop) { // TODO: improve applicability return num_loop % 2 == 0; } - __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop) { return (num_loop / 2) > 1; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..08d986d0da62c01c970ee10ceab6b8e000392b46 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace lds_direct_load { + +__device__ void sched_barrier() +{ +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + // When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of + // `sched_barrier` is necessary to make sure no instructions that use the loaded memory + // are scheduled by the compiler before the `waitcnt` instruction. + __builtin_amdgcn_sched_barrier(0); +#endif +} + +} // namespace lds_direct_load + +namespace ck { + +template +struct GridwiseGemmPipeline_v4; + +// 1-stage prefetch +template <> +struct GridwiseGemmPipeline_v4<1> +{ + static constexpr auto I0 = Number<0>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffers& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffers& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1); + auto& a_block_buf = a_block_bufs.At(I0); + auto& b_block_buf = b_block_bufs.At(I0); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// 2-stages prefetch +template <> +struct GridwiseGemmPipeline_v4<2> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffers& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const 
BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffers& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2); + auto& a_block_buf1 = a_block_bufs.At(I0); + auto& a_block_buf2 = a_block_bufs.At(I1); + auto& b_block_buf1 = b_block_bufs.At(I0); + auto& b_block_buf2 = b_block_bufs.At(I1); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 99e410f688312b47dbeba4afb1d1712624561a76..0e5777e561476da30681b67f6d0d85a152a683f5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -55,7 +55,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -457,6 +457,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 18cf80041b71c9f581462cef10f7fd3508e96930..0078660556418671d685c963ac3d25a8b82d087c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -93,7 +93,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -588,6 +588,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -1012,6 +1013,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..caf8f040f4a78232386a1f0c3d0be02b4dc63295 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -0,0 +1,1076 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + 
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t 
K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + 
ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ComputeType, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + ComputeType, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + +#if 1 + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + 
make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __syncthreads(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } +#endif + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + + __syncthreads(); + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = 
concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + 
cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index d8b31311b1040f6c77e955550875a2a3e1337530..4458b9356dd64000997a93c23fc9241d1e27d8b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ 
-18,11 +18,11 @@ namespace ck { template (p_a_grid, p_b_grid, p_c_grid, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, c_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, b_element_op, @@ -68,32 +63,32 @@ __global__ void ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__)) +#endif // end of if (defined(__gfx11__)) } template -struct GridwiseGemm_k0mk1_k0nk1_mn_wmma +struct GridwiseGemm_Wmma { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -133,103 +130,289 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma static constexpr auto I6 = Number<6>{}; static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> + // FIX ME: To be deprecated static constexpr auto K1 = Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< - decltype(GridwiseGemmPipeline_Selector())>; + using GridwiseGemmPipe = + remove_cvref_t())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { - if constexpr(ABlockLdsExtraM) + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto A_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return a_block_desc_k0perblock_mperblock_k1; + return a_block_desc; } - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + __host__ __device__ static constexpr auto MakeBBlockDescriptor() { - constexpr auto max_lds_align = K1; + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { - if constexpr(BBlockLdsExtraN) + if 
constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else { + + constexpr auto B_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); } }(); - return b_block_desc_k0perblock_nperblock_k1; + return a_block_copy_step; } - __host__ __device__ static constexpr auto - // *Caution Here repeat is shuffle repeat - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; - return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0perblock_mperblock_k1 = - GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - constexpr auto b_block_desc_k0perblock_nperblock_k1 = - GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else + constexpr auto A_KRow = I1; +#endif - constexpr auto max_lds_align = K1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // 
KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + // Err: merge transform cause non-constexpr issue + + // return transform_tensor_descriptor( + // ABlockDesc_{}, + // make_tuple(make_merge_transform(make_tuple(Number{}, I1)), + // make_pass_through_transform(Number{}), + // make_pass_through_transform(I1), + // make_pass_through_transform(I1), + // make_pass_through_transform(Number{})), + // make_tuple(Sequence<0, 3>{}, + // Sequence<1>{}, + // Sequence<2>{}, + // Sequence<4>{}, + // Sequence<5>{}), + // make_tuple( + // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, + // Sequence<4>{})); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); + return a_wave_desc; + } - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else + constexpr auto B_KRow = I1; +#endif + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); - return (a_block_space_size_aligned * sizeof(FloatA) + - b_block_space_size_aligned * sizeof(FloatB)); + return b_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool 
CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -238,23 +421,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma (NPerBlock % (NRepeat * NPerWmma)) == 0, "Invalid tuning param!"); - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + K == GetBProblemsizeNK()[I1])) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); return false; + } - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp err: ProblemSize division"); return false; + } // check gridwise gemm pipeline - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = K / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { + printf("GridwiseOp err: Pipeline not support this k_loop"); return false; } @@ -266,8 +492,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB && - b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB)) + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) { return false; } @@ -276,7 +502,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = K / KPerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -308,6 +534,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma c_grid_desc_m_n); } + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + + static constexpr auto max_lds_align = K1; + + 
static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + }; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; @@ -315,12 +572,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma remove_cvref_t; template - __device__ static void Run(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation& a_element_op, @@ -332,9 +589,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /*******************************************************************************/ // Memory buffer zone. 
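+ // a_grid_buf/b_grid_buf/c_grid_buf below are views over the global A/B/C pointers; the
+ // block-level tiles are staged either in LDS carved out of p_shared at the
+ // SharedMemTrait offsets (when A/BEnableLds is set) or in VGPRs otherwise.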
const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + p_a_grid, a_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + p_b_grid, b_grid_desc.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -352,24 +609,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); /*******************************************************************************/ -// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - constexpr auto max_lds_align = K1; - constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) + * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, -/* typename SrcData, */ FloatA, -/* typename DstData, */ FloatA, -/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), -/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, @@ -379,99 +653,197 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /* index_t SrcScalarStrideInVector, */ 1, /* index_t DstScalarStrideInVector, */ 1, /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, -/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( - a_grid_desc_k0_m_k1, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, - a_block_desc_k0perblock_mperblock_k1, + a_block_desc, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatB, - FloatB, - decltype(b_grid_desc_k0_n_k1), - 
decltype(b_block_desc_k0perblock_nperblock_k1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0perblock_nperblock_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + SharedMemTrait::b_block_space_size_aligned); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc), + decltype(b_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; /*******************************************************************************/ // GEMM - 
constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle{}; + BlockwiseGemmWMMA{}; // Prepare Register for C matrix auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); -/*******************************************************************************/ - constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - // LDS allocation for A and B: be careful of alignment - auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); - +/*******************************************************************************/ // Shift Per SUB_K - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); // gridwise GEMM pipeline - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, - a_block_desc_k0perblock_mperblock_k1, + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0perblock_nperblock_k1, + b_grid_desc, + b_block_desc, b_blockwise_copy, b_grid_buf, b_block_buf, b_block_slice_copy_step, blockwise_gemm, c_thread_buf, - K0BlockMainLoop); + KBlockMainLoop); /*******************************************************************************/ // write out to C, implement shuffle { + // C mapping in single thread. 
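(Aside on the `a_block_trait()` / `b_block_trait()` lambdas used in the hunk above: they rely on `if constexpr` inside an immediately-invoked lambda so that the LDS-staged path and the direct-to-register path can return tuples of entirely different buffer/copy types. The following is a minimal, standalone sketch of that pattern only; `std::array` and the tag values are illustrative stand-ins, not the CK buffer or transfer types.)

```cpp
#include <array>
#include <cstddef>
#include <tuple>

// Stand-in for the two storage flavours selected at compile time,
// mirroring the a_block_trait()/b_block_trait() shape in the diff above.
template <bool EnableLds, std::size_t BlockSize>
auto make_block_trait()
{
    return [&]() {
        if constexpr(EnableLds)
        {
            // LDS-staged path: a large, block-shared-sized buffer plus a tag
            // standing in for a block-wide (cooperative) copy object.
            std::array<float, BlockSize> buf{};
            return std::make_tuple(buf, /*blockwise copy tag*/ 0);
        }
        else
        {
            // Direct-to-register path: a small per-thread buffer plus a tag
            // standing in for a thread-wise copy object. The tuple types differ
            // from the branch above; `if constexpr` discards the unused branch,
            // so return-type deduction stays consistent per instantiation.
            std::array<float, BlockSize / 64> buf{};
            return std::make_tuple(buf, /*threadwise copy tag*/ 'c');
        }
    }();
}

int main()
{
    auto lds_trait  = make_block_trait<true, 256>();  // tuple<array<float,256>, int>
    auto vgpr_trait = make_block_trait<false, 256>(); // tuple<array<float,4>, char>
    return std::get<0>(lds_trait).size() > std::get<0>(vgpr_trait).size() ? 0 : 1;
}
```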
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - // This API Provide All dimension (size) you need + // C mapping in single block constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); @@ -486,8 +858,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + static_cast(p_shared) + SharedMemTrait::c_shuffle_block_space_offset, + SharedMemTrait::c_shuffle_block_space_size); constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, @@ -533,8 +905,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, // BlockSliceLengths, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatCShuffle, // typename SrcData, - FloatC, // typename DstData, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), Sequence<0, 1, 2, 3>, // typename DimAccessOrder, @@ -637,6 +1009,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma if constexpr(access_id < num_access - 1) { constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d2a06ba9afe5b89be9b19c3bb1183d7b21c219fb --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp @@ -0,0 +1,1369 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
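(Aside on the shared-memory budgeting that both the WMMA kernel above, via its `SharedMemTrait` offsets, and the `GetSharedMemoryNumberOfByte()` helper in this new file follow: the A and B tiles are packed back to back with aligned offsets, and the C-shuffle slab reuses the same LDS, so the allocation is the maximum of the two layouts. The sketch below reproduces only that arithmetic with plain `std::size_t`; the element counts, alignment, and data-type sizes are illustrative placeholders, not values taken from the kernel.)

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Round `size` up to a multiple of `align` (same role as math::integer_least_multiple).
constexpr std::size_t least_multiple(std::size_t size, std::size_t align)
{
    return (size + align - 1) / align * align;
}

int main()
{
    // Illustrative element counts and alignment; in the kernel these come from the
    // block descriptors and math::lcm(AK1, BK1).
    constexpr std::size_t a_tile_elems    = 128 * 32; // MPerBlock x KPerBlock tile of A
    constexpr std::size_t b_tile_elems    = 128 * 32; // NPerBlock x KPerBlock tile of B
    constexpr std::size_t c_shuffle_elems = 32 * 128; // one C-shuffle slab
    constexpr std::size_t align           = 8;

    constexpr std::size_t a_aligned = least_multiple(a_tile_elems, align);
    constexpr std::size_t b_aligned = least_multiple(b_tile_elems, align);

    // A starts at offset 0, B right after the aligned A region; the C-shuffle
    // slab reuses the same LDS, so the total is the max of the two layouts.
    constexpr std::size_t a_offset  = 0;
    constexpr std::size_t b_offset  = a_offset + a_aligned;
    constexpr std::size_t lds_bytes =
        std::max(a_aligned * sizeof(unsigned short) + b_aligned * sizeof(unsigned short),
                 c_shuffle_elems * sizeof(float));

    std::printf("b_offset=%zu lds_bytes=%zu\n", b_offset, lds_bytes);
    return 0;
}
```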
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + 
} + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + 
c_block_size * sizeof(CShuffleDataType)); + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + 
decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const 
CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * 
sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto 
c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff10215353cf1b47d598e7dc28bd849ee139e90b --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,2010 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_streamk_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const 
TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + 
else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec 
== GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t Streamk_sel_, + index_t Grid_size_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + Streamk_sel{Streamk_sel_}, + Grid_size{Grid_size_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, 1)}, + KPadded{CalculateKPadded(K_, 1)}, + AK0{CalculateAK0Padded(K_, 1)}, + BK0{CalculateBK0Padded(K_, 1)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel + << ", Grid size:" << Grid_size << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t Streamk_sel; + mutable index_t Grid_size; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t Streamk_sel_, + index_t Grid_size_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Problem& problem, unsigned int kbatch_id, unsigned int orig_K) + { + if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead * problem.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = 
kbatch_id * problem.KRead * problem.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = kbatch_id * problem.KRead; + } + + if(kbatch_id < static_cast(problem.KBatch - 1)) + { + problem.K = problem.KRead; + } + else + { + problem.K = orig_K - problem.KRead * (problem.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + 
c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + + if(karg.K <= 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(is_same, bhalf_t>::value) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + Problem& problem) + { + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, + problem.N, + AK0Number * problem.KPadded, + problem.Grid_size, + problem.Streamk_sel); + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block; + index_t num_k_block_main_loop; + + for(auto 
block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + while(true) + { + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, + problem.KPadded, + problem.N, + problem.NPadded, + problem.StrideB, + problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 
BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
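// A minimal host-side sketch of the Stream-K bookkeeping driving the loop
// above, under the assumption that a block owns the half-open K-iteration
// range [iter_start, iter_end) and drains it tile by tile starting from
// iter_end - 1. The arithmetic below is illustrative only and does not
// reproduce Block2CTileMap_GemmStreamK_v2 itself.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main()
{
    const std::uint32_t iters_per_tile = 8;  // K iterations needed to finish one C tile
    std::uint32_t iter_start = 5;            // this block's first global K iteration
    std::uint32_t iter_end   = 21;           // one past this block's last global K iteration

    while(iter_end > iter_start)
    {
        const std::uint32_t last        = iter_end - 1;
        const std::uint32_t tile_idx    = last / iters_per_tile; // C tile being worked on
        const std::uint32_t tile_begin  = tile_idx * iters_per_tile;
        const std::uint32_t iter_offset = std::max(tile_begin, iter_start) - tile_begin;
        const std::uint32_t length      = (last - tile_begin) + 1 - iter_offset;

        std::cout << "tile " << tile_idx << ": iter_offset " << iter_offset
                  << ", current_iter_length " << length << '\n';

        iter_end -= length; // mirrors the `iter_end <= iter_start` exit condition in Run()
    }
}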
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 
m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + // make sure next loop LDS is ready for use + block_sync_lds(); + } + } 
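// A small sketch of the epilogue cost implied by the space-filling curves above:
// the accumulator tile is drained in chunks of CShuffleMXdlPerWavePerShuffle by
// CShuffleNXdlPerWavePerShuffle XDL tiles, and each chunk makes one VGPR-to-LDS
// write plus one LDS-to-global copy, each guarded by block_sync_lds(). The
// concrete values below are illustrative assumptions, not a fixed configuration.
#include <iostream>

int main()
{
    const int MXdlPerWave = 4, NXdlPerWave = 4; // XDL tiles held per wave
    const int CShuffleMXdlPerWavePerShuffle = 2; // M tiles drained per shuffle pass
    const int CShuffleNXdlPerWavePerShuffle = 2; // N tiles drained per shuffle pass

    const int num_access = (MXdlPerWave / CShuffleMXdlPerWavePerShuffle) *
                           (NXdlPerWave / CShuffleNXdlPerWavePerShuffle);

    // Two LDS barriers per access: one before the VGPR->LDS write,
    // one before the LDS->global read.
    std::cout << "num_access = " << num_access
              << ", block_sync_lds calls = " << 2 * num_access << '\n';
}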
+ } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + Problem& problem) + { + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + Block2CTileMap_streamk block_2_ctile_map_streamk( + problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size); + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; + index_t num_k_block_main_loop; + + for(auto block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + { + + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, + problem.KPadded, + problem.N, + problem.NPadded, + problem.StrideB, + problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto 
a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + 
b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); 
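// A minimal sketch of what the single-stage merge adaptors above compute when
// CalculateBottomIndex maps the scalar thread origin back to (M0, M1, M2, M3, M4):
// a mixed-radix decomposition with the last dimension varying fastest. The
// dimension lengths and the input index used here are illustrative only.
#include <array>
#include <iostream>

int main()
{
    const std::array<int, 5> lengths = {2, 4, 4, 2, 4}; // M0, M1, M2, M3, M4
    int m = 57;                                         // linearised m_thread_data_on_block

    std::array<int, 5> idx{};
    for(int d = 4; d >= 0; --d) // peel off the fastest-varying dimension first
    {
        idx[d] = m % lengths[d];
        m /= lengths[d];
    }

    // (((0 * 4 + 1) * 4 + 3) * 2 + 0) * 4 + 1 = 57, so the result is (0, 1, 3, 0, 1).
    std::cout << "(M0,M1,M2,M3,M4) = (" << idx[0] << ',' << idx[1] << ',' << idx[2] << ','
              << idx[3] << ',' << idx[4] << ")\n";
}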
+ + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 9c09f3a53912c1fb703292c6ae6be7782fc1d7b1..c293d64ef0d983c215c6b3c7d005d37b461c9eb6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -25,7 +25,7 @@ __global__ void kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -50,7 +50,7 @@ __global__ void typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); @@ -108,7 +108,8 @@ template + typename ComputeTypeA = FloatC, + typename ComputeTypeB = ComputeTypeA> struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -547,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ComputeType) + - b_block_space_size_aligned * sizeof(ComputeType)), + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), c_block_size * sizeof(FloatCShuffle)); } @@ -750,7 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, FloatA, - ComputeType, + ComputeTypeA, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -781,7 +782,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, FloatB, - ComputeType, + ComputeTypeB, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -809,13 +810,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, + ComputeTypeA, + ComputeTypeB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -833,10 +835,10 @@ 
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..db9625c6e6032114b87209bfc47048f4c6d01b86 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -0,0 +1,1153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +#endif + kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + typename GridwiseGemm::Problem problem) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = problem; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v2 +{ + static 
constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + return CalculateKPadded(K) / AK1Value; + } + else + { + return K / AK1Value; + } + } + + __host__ static auto CalculateBK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + return CalculateKPadded(K) / BK1Value; + } + else + { + return K / BK1Value; + } + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_floor(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_floor(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - 
K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, 
Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + 
K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KPadded{CalculateKPadded(K_)}, + AK0{CalculateAK0(K_)}, + BK0{CalculateBK0(K_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + }; + + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr 
auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding) + { + if(!(CalculateKPadded(problem.K) % AK1Value == 0) || + !(CalculateKPadded(problem.K) % BK1Value == 0)) + { + return false; + } + } + else + { + if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock; + + if(num_k_loop < 4) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return num_loop > 3; + } + + __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + if(num_loop % 2 == 1) + return 3; + else + return 2; + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + 
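CheckValidity above rejects problems with fewer than four K iterations, CalculateHasMainKBlockLoop requires more than three, and CalculateKBlockLoopTailNum returns 3 for an odd iteration count and 2 for an even one. A small arithmetic sketch with hypothetical tile sizes, just to spell out those rules:

    #include <cassert>

    int main()
    {
        const int KPerBlock = 64, AK1 = 8; // made-up tile shape
        const int K   = 512;
        const int AK0 = K / AK1;                        // 64 (no K padding)
        const int num_k_loop = (AK0 * AK1) / KPerBlock; // 8 main-loop iterations

        const bool valid         = num_k_loop >= 4;               // CheckValidity rejects < 4
        const bool has_main_loop = num_k_loop > 3;                 // CalculateHasMainKBlockLoop
        const int  tail          = (num_k_loop % 2 == 1) ? 3 : 2;  // CalculateKBlockLoopTailNum

        assert(valid && has_main_loop && tail == 2);
        return 0;
    }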
make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } +#if 0 + if(threadIdx.x == 0){ + printf("Hardware assigned No. 
%03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n", + get_block_1d_id(), + block_work_idx[I0], + block_work_idx[I1], + __smid()>>6 & 0xf, + __smid()>>4 & 0x3, + __smid() & 0xf); + } +#endif + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + ComputeTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + ComputeTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + // BlockSize, + // ComputeType, + // FloatGemmAcc, + // decltype(a_block_desc_ak0_m_ak1), + // decltype(b_block_desc_bk0_n_bk1), + // MPerXdl, + // NPerXdl, + // MXdlPerWave, + // NXdlPerWave, + // KPack, + // LoopSched>(); + auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4< + BlockSize, + ComputeTypeA, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>{}; // TransposeC + + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful 
of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // gridwise GEMM pipeline + static_assert(std::is_default_constructible_v); + // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
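Inside each of the two shared allocations, the A tile starts at offset 0 and the B tile starts right after it, rounded up to the LDS alignment (the lcm of AK1 and BK1); ping buffers come from p_shared_0 and pong buffers from p_shared_1, and the C-shuffle staging buffer later reuses the same LDS. A host-side sketch of that offset and sizing arithmetic, with invented tile sizes and a 2-byte compute type assumed:

    #include <algorithm>
    #include <cassert>
    #include <cstddef>

    std::size_t round_up(std::size_t x, std::size_t align) { return (x + align - 1) / align * align; }

    int main()
    {
        const std::size_t a_tile_elems    = 64 * 256;   // hypothetical A tile element count
        const std::size_t b_tile_elems    = 64 * 128;   // hypothetical B tile element count
        const std::size_t c_shuffle_elems = 2 * 32 * 128; // C-shuffle staging slice, also invented
        const std::size_t max_lds_align   = 8;          // lcm(AK1, BK1) in elements

        const std::size_t a_space  = round_up(a_tile_elems, max_lds_align); // a_block_space_size_aligned
        const std::size_t b_offset = a_space;                               // B tile starts after the aligned A tile
        const std::size_t ab_bytes = (a_space + round_up(b_tile_elems, max_lds_align)) * 2;

        // GetSharedMemoryNumberOfByte: the C-shuffle buffer reuses the LDS after the main loop,
        // so the allocation is the max of the two phases.
        const std::size_t lds_bytes = std::max(ab_bytes, c_shuffle_elems * sizeof(float));

        assert(b_offset % max_lds_align == 0 && lds_bytes >= ab_bytes);
        return 0;
    }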
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..36797a906a267cb4a08bf19bc00efa48e1c7e945 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -0,0 +1,1987 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
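The static_for over the space-filling curve in the v2 write-out above alternates a fill phase (each thread's accumulators into the shared C-shuffle tile) with a drain phase (a blockwise vectorized store to global C), separated by block_sync_lds barriers, advancing the destination window between steps. A runnable toy sketch of that staging order; the helper names are stubs invented for this illustration, not CK functions.

    #include <cstdio>

    void block_sync() {}                                                  // stands in for block_sync_lds()
    void vgpr_to_lds(int step)   { std::printf("step %d: VGPR -> LDS\n", step); }
    void lds_to_global(int step) { std::printf("step %d: LDS  -> C\n", step); }

    int main()
    {
        const int m_shuffles = 2, n_shuffles = 2;       // MXdlPerWave / CShuffleMXdlPerWavePerShuffle, etc.
        const int num_access = m_shuffles * n_shuffles; // corresponds to sfc_c_vgpr.GetNumOfAccess()

        for(int step = 0; step < num_access; ++step)
        {
            block_sync();         // wait until the LDS tile is free again
            vgpr_to_lds(step);    // each thread writes one accumulator slice
            block_sync();         // slice complete, safe to read cooperatively
            lds_to_global(step);  // blockwise vectorized store, then the C window advances
        }
        return 0;
    }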
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + 
math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto 
a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + 
b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, 
KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 
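For split-K, CalculateKRead rounds K/KBatch up to a multiple of lcm(AK1, BK1); SplitKBatchOffset then advances each batch's A and B pointers by blockIdx.z tiles of KRead (scaled by the stride when K is not the contiguous dimension), the last batch takes whatever K remains, and when the reduce-add path is active the C pointer moves by a full M*N slice per batch. A host-side sketch of that arithmetic with invented sizes, assuming row-major A and column-major B so both offsets are plain element counts along K:

    #include <cassert>
    #include <cstdint>

    std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

    int main()
    {
        const std::int64_t M = 512, N = 512, K = 3000, KBatch = 4;
        const std::int64_t KReadVec = 8; // lcm(AK1, BK1)

        const std::int64_t KRead = ceil_div(K, KBatch * KReadVec) * KReadVec; // 752, as in CalculateKRead

        for(std::int64_t z = 0; z < KBatch; ++z)              // plays the role of blockIdx.z
        {
            const std::int64_t a_offset = z * KRead;          // row-major A: K is contiguous
            const std::int64_t b_offset = z * KRead;          // column-major B: K is contiguous
            const std::int64_t c_offset = z * M * N;          // only applied when IsReduceAdd()
            const std::int64_t k_this_batch =
                (z < KBatch - 1) ? KRead : K - KRead * (KBatch - 1); // last batch gets the remainder
            assert(k_this_batch > 0);
            (void)a_offset; (void)b_offset; (void)c_offset;
        }
        return 0;
    }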
1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + 
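Both LDS descriptors above rely on make_xor_with_modulo_transform so that K-contiguous writes and MFMA-ordered reads land in different banks. The snippet below is a generic, self-contained illustration of XOR swizzling, not the exact CK transform: with a power-of-two width, XOR-ing the column index with the row index gives a distinct permutation per row, so elements that share a logical column are spread across physical columns (and hence across banks).

    #include <cstdio>

    int swizzled_col(int row, int col, int num_cols) { return (col ^ row) % num_cols; }

    int main()
    {
        constexpr int rows = 4, cols = 4; // power-of-two width assumed
        for(int r = 0; r < rows; ++r)
        {
            for(int c = 0; c < cols; ++c)
                std::printf("%d ", swizzled_col(r, c, cols)); // each row prints a distinct permutation
            std::printf("\n");
        }
        return 0;
    }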
c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const 
BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 
max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor 
on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each 
thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto 
max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + 
b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + 
make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = 
MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f9071bd29d212d0ee8310d9efd3ca84b91a0ddf0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -0,0 +1,2491 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_as_grid, + karg.p_bs_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds( + karg.p_as_grid, + karg.p_bs_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + using LDSTypeA = ComputeTypeA; + using LDSTypeB = ComputeTypeB; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeAsGridPointer() + { + return generate_tuple( + [&](auto i) { + using ADataType_ = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeBsGridPointer() + { + return generate_tuple( + [&](auto i) { + using BDataType_ = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using AsGridPointer = decltype(MakeAsGridPointer()); + using BsGridPointer = decltype(MakeBsGridPointer()); + using DsGridPointer = decltype(MakeDsGridPointer()); + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, 
N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + 
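+            // Illustrative note (editor's sketch, not part of the original kernel): with a
+            // hypothetical tile configuration of MPerBlock = 128, KPerBlock = 64,
+            // AK1Value = BK1Value = 8, and a row-major A of M = 1000, K = 1024 split into
+            // KBatch = 2, the host-side helpers above evaluate to
+            //   MPadded = integer_least_multiple(1000, 128)   = 1024
+            //   AK0     = ((1024 + 127) / 128) * (64 / 8)     = 64   per K split
+            //   KRead   = ((1024 + 15) / 16) * 8              = 512  per K split
+            // so the descriptor returned here views each split of A as a 3-D tile of shape
+            // [AK0, MPadded, AK1Value] = [64, 1024, 8]: M is right-padded to a multiple of
+            // MPerBlock while K is left untouched in this M-padding-only specialization.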
return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto + MakeAsGridDescriptor_AK0_M_AK1(const index_t M, + const index_t MPad, + const index_t K, + const index_t KPad, + const std::array& StrideAs, + const index_t AK0) + { + return generate_tuple( + [&](auto i) { + return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0); + }, + Number{}); + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto 
b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const index_t K, + const index_t KPad, + const index_t N, + const index_t NPad, + const std::array& StrideBs, + const index_t BK0) + { + return generate_tuple( + [&](auto i) { + return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0); + }, + Number{}); + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); }, + Number{}); + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, 
+ index_t K_, + std::array StrideAs_, + std::array StrideBs_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideAs{StrideAs_}, + StrideBs{StrideBs_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + index_t StrideC; + + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(std::array p_as_grid_, + std::array p_bs_grid_, + std::array p_ds_grid_, + void* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + std::array StrideAs_, + std::array StrideBs_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideC_, k_batch_}, + p_as_grid{}, + p_bs_grid{}, + p_ds_grid{}, + p_c_grid{static_cast(p_c_grid_)}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + + { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + + // A pointer + p_as_grid(i) = static_cast(p_as_grid_[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + + // B pointer + p_bs_grid(i) = static_cast(p_bs_grid_[i]); + }); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + AsGridPointer p_as_grid; + BsGridPointer p_bs_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + +#if 0 + struct SplitKBatchOffsetMultiABD + { + __device__ 
SplitKBatchOffsetMultiABD(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + Argument& karg) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout_ = remove_cvref_t>; + if constexpr(is_same_v) + { + as_k_split_offset[i] = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + as_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideAs[i]; + } + + p_as_grid_(i) = p_as_grid[i] + as_k_split_offset[i]; + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout_ = remove_cvref_t>; + if constexpr(is_same_v) + { + bs_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideBs[i]; + } + else if constexpr(is_same_v) + { + bs_k_split_offset[i] = blockIdx.z * karg.KRead; + } + + p_bs_grid_(i) = p_bs_grid[i] + bs_k_split_offset[i]; + }); + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + AsGridPointer p_as_grid_; + BsGridPointer p_bs_grid_; + std::array as_k_split_offset; + std::array bs_k_split_offset; + }; +#endif + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeA); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. 
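+            // Editor's worked example (a minimal sketch under assumed sizes, not part of the
+            // original kernel): the kfold / KThreadReadPerm / mpair arithmetic below, evaluated
+            // for a hypothetical fp16 configuration with AK1 = 8, AK0 = 8, MPerXdl = 32,
+            // M0 (thread-cluster M length) = 4 and KThreadWrite = 8.
+#if 0
+            {
+                constexpr index_t AK1_          = 8;
+                constexpr index_t AK0_          = 8;
+                constexpr index_t M0_           = 4;
+                constexpr index_t MPerXdl_      = 32;
+                constexpr index_t KThreadWrite_ = 8;
+                constexpr index_t elem_bytes_   = 2; // sizeof(half_t)
+
+                constexpr index_t K0PerThreadWrite_ = AK0_ / KThreadWrite_; // 1
+                constexpr index_t KThreadRead_      = 64 / MPerXdl_;        // 2
+                constexpr index_t K0PerThreadRead_  = AK0_ / KThreadRead_;  // 4
+
+                // 8 * 4 * 2 = 64 bytes per (AK1, M0) slab < 128, so two K0 slabs get folded
+                constexpr index_t kfold_ = (AK1_ * M0_ * elem_bytes_ > 128)
+                                               ? 1
+                                               : 128 / (AK1_ * M0_ * elem_bytes_); // 2
+
+                // kfold_ * K0PerThreadWrite_ / K0PerThreadRead_ == 0 in integer math, so the
+                // read permutation falls back to KThreadRead_
+                constexpr index_t KThreadReadPerm_ =
+                    (kfold_ * K0PerThreadWrite_ / K0PerThreadRead_) > 1
+                        ? KThreadRead_ / (kfold_ * K0PerThreadWrite_ / K0PerThreadRead_)
+                        : KThreadRead_; // 2
+
+                // one XDL row already spans 8 * 32 * 2 = 512 bytes > 128, so no M pairing
+                constexpr index_t mpair_ =
+                    (AK1_ * MPerXdl_ * elem_bytes_ > 128)
+                        ? 1
+                        : ((128 / (AK1_ * MPerXdl_ * elem_bytes_)) > M0_
+                               ? M0_
+                               : 128 / (AK1_ * MPerXdl_ * elem_bytes_)); // 1
+
+                static_assert(kfold_ == 2 && KThreadReadPerm_ == 2 && mpair_ == 1,
+                              "worked example with the assumed sizes above");
+            }
+#endif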
+ constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeB); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + + b_block_space_size_aligned * sizeof(LDSTypeB)), + 
c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& 
b_element_op, + const CElementwiseOperation& c_element_op) + { + // std::array StrideAs = {problem.StrideA}; + // std::array StrideBs = {problem.StrideB}; + + // AsGridPointer p_as_grid; + // BsGridPointer p_bs_grid; + // DsGridPointer p_ds_grid; + + // const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + // problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + // const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + // problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + +#if 0 + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_m_n(j) = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs[j]); + }); +#endif + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // const auto a_grid_buf = make_dynamic_buffer( + // p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + // const auto b_grid_buf = make_dynamic_buffer( + // p_bs_grid[I0], b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A 
matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + +#if 0 + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); +#else + const auto idx_as_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + decltype(as_grid_desc_ak0_m_ak1), + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; +#endif + +#if 0 + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); +#else + const auto idx_bs_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + decltype(bs_grid_desc_bk0_n_bk1), + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; 
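+
+        // [editorial note, not part of the original patch] The ThreadGroupTensorSliceTransfer_v7r2
+        // copies constructed above replace the single-source v4r1 path kept in the disabled
+        // #if 0 branches: they take a tuple of A-side (or B-side) global tensors and apply the
+        // A/B elementwise operation while the data is staged, so a single A LDS tile and a
+        // single B LDS tile serve all inputs. A minimal sketch of that idea follows; it is
+        // illustrative only (hypothetical lambda, two assumed inputs, functional-style op,
+        // no vectorization or AK0/M/AK1 descriptor math) and is not the CK transfer API.
+#if 0
+        auto stage_fused_a_tile = [](const LDSTypeA* a0_tile, // first A-side input (assumed)
+                                     const LDSTypeA* a1_tile, // second A-side input (assumed)
+                                     LDSTypeA*       a_lds,   // single A tile in LDS
+                                     index_t         tile_size,
+                                     auto            a_op)    // stands in for AElementwiseOperation
+        {
+            for(index_t i = 0; i < tile_size; ++i)
+            {
+                // combine all A inputs elementwise while writing the one LDS tile
+                a_lds[i] = a_op(a0_tile[i], a1_tile[i]);
+            }
+        };
+#endif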
+ +#endif + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(as_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + as_grid_buf, + a_block_buf, + a_block_slice_copy_step, + bs_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + bs_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + +#if 0 + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; +#else + using EDataType = CDataType; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock = + CShuffleBlockTransferScalarPerVector_NPerBlock; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // 
ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + +#endif + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); +#if 0 + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + +#else + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); +#endif + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + +#if 0 + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } +#else + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } +#endif + }); + } + } + +#if 1 + template + __device__ static void Run_2Lds(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) + { + // const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + // problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + // const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + // problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto 
as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // const auto a_grid_buf = make_dynamic_buffer( + // p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + // const auto b_grid_buf = make_dynamic_buffer( + // p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + +#if 0 + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 
BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); +#else + const auto idx_as_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + decltype(as_grid_desc_ak0_m_ak1), + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; + +#endif + +#if 0 + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); +#else + const auto idx_bs_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + decltype(bs_grid_desc_bk0_n_bk1), + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; +#endif + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + 
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(as_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + as_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + bs_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + bs_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + 
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + +#if 0 + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; +#else + using EDataType = CDataType; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + 
generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock = + CShuffleBlockTransferScalarPerVector_NPerBlock; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + c_element_op}; + +#endif + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + +#if 1 + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; +#endif + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + +#if 0 + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + 
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } +#else + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } +#endif + }); + } + } +#endif +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c7038ed4fa358198e447ae4afec7f7291bd46358 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -0,0 +1,2144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemmMultiD_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto 
CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == 
GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + 
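+ // Sketch of the index math only (not the exact descriptor implementation):
+ // MakeAGridDescriptor_AK0_M_AK1 / MakeBGridDescriptor_BK0_N_BK1 follow the same two-step
+ // recipe: (1) right-pad M/N and/or K according to GemmSpec, then (2) unmerge the K axis into
+ // (K0, K1) so the innermost K1 elements can be loaded as one vector. For a row-major A the
+ // resulting 3-d coordinate maps back to memory roughly as
+ //
+ //   A[ak0][m][ak1]  ->  p_a_grid[ m * StrideA + (ak0 * AK1Value + ak1) ]
+ //
+ // with padded (out-of-range) coordinates never contributing to the C output.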
b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + 
index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // 
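+ // Illustrative example of the SplitKBatchOffset logic below, which slices the K dimension
+ // across blockIdx.z. Assumed numbers: K = 1000, KBatch = 4, KRead = 256, row-major A
+ // (K contiguous):
+ //
+ //   blockIdx.z : 0     1     2     3
+ //   a offset   : 0     256   512   768                 // blockIdx.z * KRead elements along K
+ //   karg.K     : 256   256   256   1000 - 768 = 232    // the last split takes the remainder
+ //
+ // For a column-major A the element offset is additionally scaled by StrideA, as in the code.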
in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeA); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? 
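+ // Illustrative note on the xor_with_modulo transform used above: it is a bank-conflict
+ // avoidance swizzle. Storing the tile row-by-row would make a logical column hit the same LDS
+ // banks; permuting the index with XOR before forming the linear LDS address spreads those
+ // accesses. A generic version of the idea (not CK's exact mapping) is
+ //
+ //   lds_col = col ^ (row % SWIZZLE_PERIOD);   // SWIZZLE_PERIOD is a placeholder
+ //   lds_off = row * ROW_PITCH + lds_col;
+ //
+ // MLdsLayer / NLdsLayer above roughly measure how many K-rows of the tile fit into one
+ // 128-byte (32 banks x 4 bytes) LDS stride and so set the swizzle granularity.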
M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeB); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + + b_block_space_size_aligned * sizeof(LDSTypeB)), + 
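+ // Illustrative LDS-budget example for GetSharedMemoryNumberOfByte() below, with assumed values
+ // MPerBlock = NPerBlock = 128, KPerBlock = 64 and 16-bit A/B/CShuffle types (ignoring the
+ // ExtraM/ExtraN padding columns and alignment):
+ //
+ //   A tile : 64 * 128 * 2 B = 16 KiB
+ //   B tile : 64 * 128 * 2 B = 16 KiB      -> A + B = 32 KiB
+ //   C-shuffle tile (one CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle slice)
+ //            is typically a fraction of that,
+ //
+ // so the result is max(32 KiB, C-shuffle bytes) = 32 KiB here, which fits the 64 KiB LDS of a
+ // gfx9 CU. Taking the max is valid because the C shuffle only starts after the last A/B tile
+ // has been consumed, so the same LDS space can be reused.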
c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
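+ // Summary of what CheckValidity rejects (restating the checks around this point):
+ //   * M (or N) not a multiple of MPerBlock (NPerBlock) when the chosen GemmSpec does not pad it;
+ //   * K not a multiple of KBatch * KPerBlock when K is not padded, or a split-K partition that
+ //     would leave the last batch with no work;
+ //   * M/N/K not divisible by the A/B/CDE vector widths along each tensor's contiguous dimension;
+ //   * too few K iterations for the selected multi-stage pipeline (num_k_loop <= PrefetchStages).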
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + 
const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + 
BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
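+ // Outline of the epilogue that follows ("shuffle C and write out"), per C-shuffle slice:
+ //   1. block_sync_lds();                   // A/B tiles are no longer needed, reuse the LDS
+ //   2. each thread copies its accumulator sub-tile from VGPRs into the LDS C-shuffle tile;
+ //   3. block_sync_lds();
+ //   4. the thread group reads the LDS tile together with the matching D-tensor elements,
+ //      applies CElementwiseOperation, and writes the result (E) to global memory;
+ //   5. advance the D/E windows along a space-filling curve and repeat for the next slice.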
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = 
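+ // Sketch of the math behind the single-stage adaptors below (not the descriptor code itself):
+ // they invert a merge transform, decomposing a flat intra-block coordinate into the
+ // (wave, xdlops-lane) digits of the accumulator layout. For lengths (M0, M1, M2, M3, M4) the
+ // forward merge is mixed-radix arithmetic,
+ //
+ //   m_flat = (((m0 * M1 + m1) * M2 + m2) * M3 + m3) * M4 + m4,
+ //
+ // and CalculateBottomIndex recovers (m0..m4) from m_flat by successive div/mod with the same
+ // radices.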
MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's 
safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run_2Lds( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + 
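+ // Illustrative note on the readfirstlane usage here and above: __builtin_amdgcn_readfirstlane
+ // broadcasts lane 0's value to every lane of the wavefront, which also tells the compiler the
+ // value is wave-uniform and can live in a scalar register (SGPR) instead of a per-lane VGPR.
+ // The block/tile indices are identical across lanes anyway, so e.g.
+ //
+ //   const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
+ //
+ // is purely a register-allocation hint that frees VGPRs for the MFMA accumulators.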
const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) 
* a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + 
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // 
ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..da6b1b304e58b06b3399208af222ce0ef81d277f --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -0,0 +1,1694 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. 
Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 +{ + using AScaleType = float; + using BScaleType = float; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 
1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return 
a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return 
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + 
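+        // Worked example for the derived members above (hypothetical tuning parameters,
+        // for illustration only): with MPerBlock = 256, KPerBlock = 64, AK1 = BK1 = 8 and a
+        // problem of M = 1000, K = 4096, KBatch = 2, the Calculate* helpers give
+        //   MPadded = integer_least_multiple(1000, 256)      = 1024,  MBlock = 4
+        //   KRead   = ceil(4096 / (2 * lcm(8, 8))) * 8        = 2048  (K elements read per split)
+        //   KPadded = ceil(4096 / (2 * 64)) * 64              = 2048  (per-split padded K)
+        //   AK0     = ceil(4096 / (2 * 64)) * (64 / 8)        = 256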
}; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + const AScaleType* p_a_scale_grid_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AScaleType* p_a_scale_grid; + const BScaleType* p_b_scale_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeA); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeB); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + + b_block_space_size_aligned * sizeof(LDSTypeB)), + 
c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, 
problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + 
BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + const index_t ScaleSliceSizeM = 1; + const index_t ScaleSliceSizeN = 1; + const index_t ScaleSliceSizeK = 1; + + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + 1, + 1, + false>( + a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0)); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + 1, + 1, + false>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1); + + const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock; + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop, + num_k_block_per_scale); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + 
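+            // Epilogue overview (descriptive comment): the accumulators held in VGPR are
+            // written out one CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle
+            // tile at a time. For each step of the space filling curve built below, the block
+            //   1) syncs LDS and copies one shuffle tile from VGPR into the C-shuffle LDS buffer,
+            //   2) syncs again, then copies that tile, fused with the D tensors through
+            //      CElementwiseOperation, from LDS to the E/C tensor in global memory,
+            //   3) advances the Ds and E slice windows to the next tile.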
// TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + 
m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + 
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 0404d88ab8b4faebbe8613453cb4274cf631fe2b..7f815de1f9d3ffc145664e8781d6f7af20cbcde1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -58,7 +58,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // TODO ANT: separate into MMA + Epilogue @@ -495,6 +495,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index bbd01a238e5d5f5d071ddec519a08d68e8fef206..8675a9242acf0d573cc0316f3f89b95b654b38af 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -494,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< TileMathThreadGroupSize, ABDataType, + ABDataType, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 0920a17fc7f9cfa6265f1599d6deade702fd53ed..5617f67f8be20f9d775f157c6e24804c4a906241 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -139,7 +139,8 @@ __host__ __device__ 
constexpr auto make_merge_transform_v4_no_carry(const LowLen } template (p_a_grid, @@ -181,21 +182,22 @@ __global__ void c_element_op, c_block_cluster_adaptor); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = ComputeTypeA> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight { static constexpr auto I0 = Number<0>{}; @@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // denorm test fix, required to work around fp16 mfma issue // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update - // FloatABAdjusted -> FloatAB throughout this file + // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB, + // throughout this file #if CK_WORKAROUND_DENORM_FIX - using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; + using FloatAAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeA>; + using FloatBAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeB>; #else - using FloatABAdjusted = FloatAB; + using FloatAAdjusted = ComputeTypeA; + using FloatBAdjusted = ComputeTypeB; #endif // M0/M1/M1Padding @@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + return math::max((a_block_space_size * sizeof(FloatAAdjusted) + + b_block_space_size * sizeof(FloatBAdjusted)), c_block_size * sizeof(FloatC)); } @@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, @@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, MPerBlock, K1>, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatA, + FloatAAdjusted, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, NPerBlock, K1>, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatB, + FloatBAdjusted, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -733,11 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // sanity 
check constexpr index_t KPack = - math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + math::max(K1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size, + static_cast(p_shared) + a_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize()); // gridwise GEMM pipeline diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index b12bcee0f414c6e56b6550c084630b86ab153bb5..7c401a49572a6f72c0e3a4c9d46d4bc4515f73cd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -45,7 +45,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bac8c3288611001ee154e2b500ba64ff9979f909 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -0,0 +1,961 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
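+// Split-K GEMM grid kernel: A and B tiles are copied from global memory directly into LDS
+// (thread_group_tensor_slice_transfer_direct_load) instead of being staged through registers.
+// The grid is launched as (N blocks, M blocks, k_batch), so blockIdx.z selects the K-batch,
+// each batch covering K0Padded * K1 elements of the reduction dimension.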
+ +#pragma once + +#include "ck/utility/amd_lds.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); +#else + ignore = karg; + ignore = b2c_map; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdlops_splitk_lds_direct_load +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, K1* K0PerBlock}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t K0Padded; + index_t k_batch; + + Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + MPadded(MPadded_), + NPadded(NPadded_), + KPadded(KPadded_), + 
K0Padded(K0Padded_), + k_batch(k_batch_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "K0Padded:" << K0Padded << ", " + << "KB:" << k_batch << "}" << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + // prefer this to be called on host + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1) + { + // k_batch * k0 * k0_per_block * k1 + auto K_t = K_Batch * K0PerBlock * K1; + return (K + K_t - 1) / K_t * K0PerBlock; + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K0Padded = CalculateK0Padded(K, K_Batch); + return K_Batch * K0Padded * K1; + } + + __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, + index_t MPad, + index_t K, + index_t StrideA, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto 
MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, + index_t NPad, + index_t N, + index_t StrideB, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto 
a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max(NumGemmKPrefetchStage * (a_block_space_size + b_block_space_size) * + sizeof(ComputeType), + c_block_size * sizeof(FloatC)); + } + + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) + { + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.k_batch * K0PerBlock * K1; + if(!(karg.K % K_t == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + + const auto num_k_loop = karg.K0Padded / K0PerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0Padded = + math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0Padded * K1; + return KPad; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded) + { + const index_t num_loop = K0Padded / K0PerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + // return block_id to C matrix tile idx (m0, n0, k_split) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap() + { + return BlockToCTileMap_3DGrid_KSplit(); + } + + using CGridDesc_M_N = remove_cvref_t; + using DefaultBlock2CTileMap = remove_cvref_t; + + template + __device__ static void Run(const Argument& karg, + void* __restrict__ p_shared_block, + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + // Elementwise operations are not supported for A and B, arguments left only for the API + // consistency. + (void)a_element_op; + (void)b_element_op; + + const FloatA* p_a_grid = karg.p_a_grid; + const FloatB* p_b_grid = karg.p_b_grid; + FloatC* p_c_grid = karg.p_c_grid; + const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [KBatch, M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]); + const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if 
constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_K0_M_K1, + FloatA, + ComputeType, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_K0_N_K1, + FloatB, + ComputeType, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, // ComputeType A + ComputeType, // ComputeType B + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + K1, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); + const auto b_buffers_offset = a_block_space_size * NumGemmKPrefetchStage; + auto b_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / + (K0PerBlock * K1)); + + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + gridwise_gemm_pipeline.template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buffers, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + 
b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buffers, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + 
m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index 4408b34870e93b2e681160fee12e0743269ff6f7..e9190dee292064fb1b5b31d249c88527718768ea 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -37,7 +37,8 @@ __global__ void index_t StrideC, typename GridwiseGemm::Block2CTileMap block_mapping) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -489,6 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1(p_a_grid, @@ -70,7 +70,7 @@ __global__ void kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const auto a_grid_desc_k0_m_k1 = @@ -175,7 +175,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; } - __host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); } + __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); } // Argument struct Problem @@ -194,7 +194,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 StrideC{StrideC_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, - K0{CalculateK0(K)} + K0{CalculateK0(K_)} { } @@ -369,9 +369,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 "Invalid tuning param!"); // check gridwise gemm pipeline - const index_t K0 = problem.K / K1Value; - const auto num_k_loop = K0 / K0PerBlock; - + const auto 
num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; @@ -383,7 +381,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -426,6 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) { return transform_tensor_descriptor( a_grid_desc_m_k, @@ -874,7 +892,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } }(); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) { return transform_tensor_descriptor( b_grid_desc_k_n, @@ -908,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } }(); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) { return transform_tensor_descriptor(c_grid_desc_m_n, make_tuple(make_right_pad_transform(M, MPad - M), @@ -958,6 +996,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.K0 % K0PerBlock == 0)) + { + return false; + } + } + if constexpr(is_same::value) { if(problem.K % ABlockTransferSrcScalarPerVector != 0) @@ -989,8 +1038,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } // check gridwise gemm pipeline - const index_t K0 = problem.K / K1; - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 
19fbee727f77ae8b6b6f90a5a6f520accde95332..7d8e94c00166cb3933cdb9bef00dc7a352b7a8b8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -43,7 +43,7 @@ __global__ void const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 371281dfe20dddacf861677a771c3a425028a27f..6ee279a3f1c5cc4f2fe2d2d0c4408b7be2e3c05d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -9,7 +9,6 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -22,25 +21,34 @@ namespace ck { template + typename Block2CTileMap, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, - const Block2CTileMap& b2c_map) + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; GridwiseGemm::template Run( - karg, static_cast(p_shared), b2c_map); + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); #else ignore = karg; ignore = b2c_map; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -87,7 +95,10 @@ template + typename ComputeTypeA = FloatC, + typename ComputeTypeB = ComputeTypeA, + typename LDSTypeA = ComputeTypeA, + typename LDSTypeB = ComputeTypeB> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { static constexpr auto I0 = Number<0>{}; @@ -127,7 +138,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 index_t MPadded; index_t NPadded; index_t KPadded; - index_t K0; + index_t K0Padded; index_t k_batch; Argument(const FloatA* p_a_grid_, @@ -142,7 +153,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 index_t MPadded_, index_t NPadded_, index_t KPadded_, - index_t K0_, + index_t K0Padded_, index_t k_batch_) : p_a_grid(p_a_grid_), p_b_grid(p_b_grid_), @@ -156,7 +167,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 MPadded(MPadded_), NPadded(NPadded_), KPadded(KPadded_), - K0(K0_), 
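// Editor's note: illustrative only, not part of the patch. A self-contained sketch of the
// split-K padding arithmetic behind CalculateK0Padded/CalculateKPadded and the K0 -> K0Padded
// rename in this hunk: K is rounded up so that it divides evenly into K_Batch * K0PerBlock * K1.
// The numeric values below are hypothetical.
#include <cstdio>

int main()
{
    const int K          = 1000; // reduction length (hypothetical)
    const int K_Batch    = 4;    // split-K factor (hypothetical)
    const int K0PerBlock = 8;
    const int K1         = 8;

    const int K_t      = K_Batch * K0PerBlock * K1;        // 256
    const int K0Padded = (K + K_t - 1) / K_t * K0PerBlock; // ceil(1000/256) * 8 = 32
    const int KPadded  = K_Batch * K0Padded * K1;          // 4 * 32 * 8 = 1024 >= K

    std::printf("K0Padded=%d KPadded=%d\n", K0Padded, KPadded);
    return 0;
}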
+ K0Padded(K0Padded_), k_batch(k_batch_) { } @@ -173,7 +184,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KP:" << KPadded << ", " - << "K0:" << K0 << ", " + << "K0Padded:" << K0Padded << ", " << "KB:" << k_batch << "}" << std::endl; } }; @@ -196,7 +207,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 return math::integer_least_multiple(N, NPerBlock); } - __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1) { // k_batch * k0 * k0_per_block * k1 auto K_t = K_Batch * K0PerBlock * K1; @@ -205,8 +216,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) { - auto K0 = CalculateK0(K, K_Batch); - return K_Batch * K0 * K1; + auto K0Padded = CalculateK0Padded(K, K_Batch); + return K_Batch * K0Padded * K1; } __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, @@ -214,7 +225,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 index_t K, index_t StrideA, index_t KBatch, - index_t K0, + index_t K0Padded, index_t KPad) { const auto a_grid_desc_m_k = [&]() { @@ -228,30 +239,57 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } }(); - const auto a_grid_desc_m_kpad = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) { + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; return transform_tensor_descriptor( a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), make_right_pad_transform(M, MPad - M)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); } - else + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) { + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return transform_tensor_descriptor( a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + 
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), make_pass_through_transform(M)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -263,7 +301,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 index_t N, index_t StrideB, index_t KBatch, - index_t K0, + index_t K0Padded, index_t KPad) { const auto b_grid_desc_k_n = [&]() { @@ -277,30 +315,57 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } }(); - const auto b_grid_desc_kpad_n = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) { + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; return transform_tensor_descriptor( b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), make_right_pad_transform(N, NPad - N)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); } - else + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) { + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return transform_tensor_descriptor( b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), make_pass_through_transform(N)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -367,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto c_block_size = 
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType), + return math::max(a_block_space_size * sizeof(LDSTypeA) + + b_block_space_size * sizeof(LDSTypeB), c_block_size * sizeof(FloatC)); } @@ -380,15 +446,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(!(karg.M % MPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || @@ -396,12 +463,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(!(karg.N % NPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } -#endif // DEBUG_LOG + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.k_batch * K0PerBlock * K1; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -410,13 +496,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.K % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -424,13 +510,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.M % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -439,13 +525,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.N % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -453,13 +539,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.K % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -468,14 +554,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) { -#if DEBUG_LOG - std::cout - << "Arg N (" << karg.N - << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" - << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -483,27 +569,29 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) { -#if DEBUG_LOG - std::cout - << "Arg M (" << karg.M - << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" - << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } - const auto num_k_loop = karg.K0 / K0PerBlock; + const auto num_k_loop = karg.K0Padded / K0PerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { -#if DEBUG_LOG - std::cout << "The number of k loops (" << num_k_loop - << ") value is not supported by GridwiseGemm Pipeline." 
- << " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The number of k loops (" << num_k_loop + << ") value is not supported by GridwiseGemm Pipeline." + << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } @@ -512,14 +600,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) { - const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; - const index_t KPad = KBatch * K0 * K1; + const index_t K0Padded = + math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0Padded * K1; return KPad; } - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded) { - const index_t num_loop = K0 / K0PerBlock; + const index_t num_loop = K0Padded / K0PerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -577,22 +666,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 typename Block2CTileMap> __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block, - const Block2CTileMap& block_2_ctile_map) + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) { const FloatA* p_a_grid = karg.p_a_grid; const FloatB* p_b_grid = karg.p_b_grid; FloatC* p_c_grid = karg.p_c_grid; const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( - karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded); const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( - karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded); + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); - const AElementwiseOperation a_element_op = AElementwiseOperation{}; - const BElementwiseOperation b_element_op = BElementwiseOperation{}; - const CElementwiseOperation c_element_op = CElementwiseOperation{}; const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); @@ -701,7 +790,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, FloatA, - ComputeType, + LDSTypeA, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -731,7 +820,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, FloatB, - ComputeType, + LDSTypeB, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -761,7 +850,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto blockwise_gemm = 
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, + LDSTypeA, + LDSTypeB, FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), @@ -770,7 +860,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 MRepeat, NRepeat, K1, - LoopSched>(); + LoopSched, + ComputeTypeA, + ComputeTypeB>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -778,8 +870,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto a_block_space_size = math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - ComputeType* p_a_block = static_cast(p_shared_block); - ComputeType* p_b_block = static_cast(p_shared_block) + a_block_space_size; + auto p_a_block = reinterpret_cast(p_shared_block); + auto p_b_block = reinterpret_cast(p_a_block + a_block_space_size); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 8d7bd9a8d1eeae302e97fd465ba1678eac1b8bba..15c64f2e474272ac05f5a901e9bb8899103875d7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -47,7 +47,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -451,6 +451,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( @@ -471,6 +471,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( @@ -489,6 +489,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_tensor_rearrange(const InputGridDesc in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const index_t batch_count, + const Block2ETileMap block_2_tile_map, + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) + GridwiseTensorRearrangeKernel::Run(in_grid_desc, + p_in_global, + out_grid_desc, + p_out_global, + batch_count, + block_2_tile_map, + compute_ptr_offset_of_batch); +#else + ignore = in_grid_desc; + ignore = p_in_global; + ignore = out_grid_desc; + ignore = p_out_global; + ignore = batch_count; + ignore = block_2_tile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +struct GridwiseTensorRearrange +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using ThisThreadBlock = ThisThreadBlock; + + __device__ static void Run(const InputGridDesc& in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc& out_grid_desc, + 
OutputDataType* __restrict__ p_out_global, + const index_t batch_count, + const Block2ETileMap& block_2_tile_map, + const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch) + { + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t k_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock); + + auto copy_global_to_global = + ThreadGroupTensorSliceTransfer_v7, + Tuple, + decltype(tie(in_grid_desc)), + decltype(tie(out_grid_desc)), + tensor_operation::element_wise::PassThrough, + Sequence(DstInMemOp)>, + Sequence, + ThreadClusterLengths, + Sequence<0, 1>, + Sequence<0, 1>, + I1, + ScalarPerVector, + Sequence, + Sequence>{ + in_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + out_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + tensor_operation::element_wise::PassThrough{}}; + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + // Global Memory + const index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + const auto in_global_buf = make_dynamic_buffer( + p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize()); + auto out_global_buf = make_dynamic_buffer( + p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize()); + + copy_global_to_global.Run( + tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf)); + } + + __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc, + const OutputGridDesc& out_grid_desc) + { + if(in_grid_desc.GetLength(I0) % MPerBlock != 0 || + in_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + if(out_grid_desc.GetLength(I0) % MPerBlock != 0 || + out_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + return true; + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8a0e16d7f6e03cda523fc298070443154d97e992 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp @@ -0,0 +1,554 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
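The block-to-batch mapping used by kernel_tensor_rearrange above can be illustrated with a minimal host-side sketch (not part of the patch). It assumes hypothetical values for grid_size, batch_count and a flat per-batch element stride standing in for ComputePtrOffsetOfStridedBatch::GetAPtrOffset; the kernel performs the same division on the 1-D block id before applying the batch pointer offsets.

// Host-side sketch of the block-to-batch index arithmetic (hypothetical values).
#include <cstdio>

int main()
{
    const int grid_size    = 8;    // hypothetical number of thread blocks in the launch
    const int batch_count  = 2;    // hypothetical number of batches
    const int batch_stride = 1024; // hypothetical element stride between consecutive batches

    const int num_blocks_per_batch = grid_size / batch_count;
    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        // batch index this block works on, as in g_idx = get_block_1d_id() / num_blocks_per_batch
        const int g_idx = block_id / num_blocks_per_batch;
        // per-batch pointer offset, standing in for GetAPtrOffset(g_idx) / GetCPtrOffset(g_idx)
        const long a_batch_offset = static_cast<long>(g_idx) * batch_stride;
        std::printf("block %d -> batch %d, input offset %ld\n", block_id, g_idx, a_batch_offset);
    }
    return 0;
}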
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// Tensor Shape +// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1] + +// Flow: +// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size): +// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True) +// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True) +// b = (db * x_mean - ds) * inv_std ** (3) / reduce_size +// c = -b * x_mean - db * inv_std / reduce_size +// dx = inv_std * dy * gamma + b * x + c +// return dx + +template +struct GridwiseNormalizationBwdData_mk_to_mk +{ + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || + (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) || + (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) || + (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) || + (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_K = Sequence; + + using DYThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using XThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using GammaThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using MeanInvStdThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using DXThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder; + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + 
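The Flow comment above is the contract this struct implements. The following minimal scalar reference (not part of the patch) mirrors it for a single row of length K, with hypothetical std::vector buffers standing in for the grid descriptors and thread buffers.

// Scalar reference for one [1, K] row; the kernel evaluates the same ds/db/b/c/dx
// sequence per M-slice, with the K sums replaced by blockwise reductions.
#include <cstddef>
#include <vector>

std::vector<float> normalization_backward_x_row(const std::vector<float>& dy,
                                                const std::vector<float>& x,
                                                const std::vector<float>& gamma,
                                                float x_mean,
                                                float inv_std)
{
    const float reduce_size = static_cast<float>(dy.size());

    // ds = sum_k(dy * gamma * x), db = sum_k(dy * gamma)
    float ds = 0.f, db = 0.f;
    for(std::size_t k = 0; k < dy.size(); ++k)
    {
        ds += dy[k] * gamma[k] * x[k];
        db += dy[k] * gamma[k];
    }

    // b = (db * x_mean - ds) * inv_std^3 / reduce_size
    const float b = (db * x_mean - ds) * inv_std * inv_std * inv_std / reduce_size;
    // c = -b * x_mean - db * inv_std / reduce_size
    const float c = -b * x_mean - db * inv_std / reduce_size;

    // dx = inv_std * dy * gamma + b * x + c
    std::vector<float> dx(dy.size());
    for(std::size_t k = 0; k < dy.size(); ++k)
        dx[k] = inv_std * dy[k] * gamma[k] + b * x[k] + c;
    return dx;
}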
__device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& mean_grid_desc_m_k, + const GridDesc_M_K& inv_std_grid_desc_m_k, + const GridDesc_M_K& dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + // Global + const auto dy_global_val_buf = make_dynamic_buffer( + p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize()); + + const auto inv_std_global_val_buf = make_dynamic_buffer( + p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize()); + + auto dx_global_val_buf = make_dynamic_buffer( + p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize()); + + // VGPR + auto dy_thread_buf = StaticBuffer{}; + + auto x_thread_buf = StaticBuffer{}; + + auto gamma_thread_buf = StaticBuffer{}; + + auto mean_thread_buf = StaticBuffer{}; + + auto inv_std_thread_buf = StaticBuffer{}; + + auto dx_thread_buf = StaticBuffer{}; + + auto ds_thread_buf = + StaticBuffer{}; + + auto db_thread_buf = + StaticBuffer{}; + + // thread id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // IO + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_mean_load = + ThreadwiseTensorSliceTransfer_v2( + mean_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_inv_std_load = + ThreadwiseTensorSliceTransfer_v2( + inv_std_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dx_store = + ThreadwiseTensorSliceTransfer_v1r3( + dx_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * 
MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + ComputeDataType reduce_size = type_convert( + dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + ds_thread_buf(I) = type_convert(0.0f); + db_thread_buf(I) = type_convert(0.0f); + }); + + // Separate sweep once and sweep twice pipeline + // Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K + // we don't need to use loop to read x, dy, gamma twice + if constexpr(SweepOnce) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_val_buf); + + } // end of sweep once + else // Sweep Twice pipeline + { + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), 
+ gamma_thread_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + } // end of first sweep + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + // reverse read for using dy, gamma and x in the cache + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + // move to tail + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k); + + // move from start to tail + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + 
dx_grid_desc_m_k, + dx_global_val_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21248e3a0a805d81756fd3b2ab4dfe4e73a85915 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// dgamma = reduce_sum(dy * (x - mean) * inv_std) +// dbeta = reduce_sum(dy) +template +struct GridwiseNormalizationBwdGammaBeta_mk_to_k +{ + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || + (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) || + (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + // do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize, + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_K = Sequence; + + using DYThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using XThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using MeanInvStdThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + 
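As a reference for the reduction this struct performs, here is a minimal scalar sketch (not part of the patch) that accumulates dy * (x - mean) * inv_std and dy along the K axis of a hypothetical row-major [M, K] buffer, matching the per-thread accumulation and blockwise sum below; mean and inv_std are assumed to be pre-broadcast to [M, K] for simplicity.

// Scalar reference: dgamma[m] = sum_k dy * (x - mean) * inv_std, dbeta[m] = sum_k dy.
#include <cstddef>
#include <vector>

void normalization_backward_gamma_beta(const std::vector<float>& dy,      // [M * K]
                                       const std::vector<float>& x,       // [M * K]
                                       const std::vector<float>& mean,    // [M * K], broadcast
                                       const std::vector<float>& inv_std, // [M * K], broadcast
                                       std::size_t M,
                                       std::size_t K,
                                       std::vector<float>& dgamma, // [M]
                                       std::vector<float>& dbeta)  // [M]
{
    dgamma.assign(M, 0.f);
    dbeta.assign(M, 0.f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
        {
            const std::size_t i = m * K + k;
            dgamma[m] += dy[i] * inv_std[i] * (x[i] - mean[i]);
            dbeta[m] += dy[i];
        }
}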
using PassThroughOp = tensor_operation::element_wise::PassThrough; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& mean_grid_desc_m_k, + const GridDesc_M_K& inv_std_grid_desc_m_k, + const GridDesc_M& dgamma_grid_desc_m, + const GridDesc_M& dbeta_grid_desc_m, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DGammaDataType* const __restrict__ p_dgamma_global, + DBetaDataType* const __restrict__ p_dbeta_global) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + // Global + const auto dy_global_val_buf = make_dynamic_buffer( + p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize()); + + const auto inv_std_global_val_buf = make_dynamic_buffer( + p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize()); + + auto dgamma_global_val_buf = make_dynamic_buffer( + p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize()); + + auto dbeta_global_val_buf = make_dynamic_buffer( + p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize()); + + // VGPR + auto dy_thread_buf = StaticBuffer{}; + + auto x_thread_buf = StaticBuffer{}; + + auto mean_thread_buf = StaticBuffer{}; + + auto inv_std_thread_buf = StaticBuffer{}; + + auto dgamma_thread_buf = + StaticBuffer{}; + + auto dbeta_thread_buf = + StaticBuffer{}; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // IO + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_mean_load = + ThreadwiseTensorSliceTransfer_v2( + mean_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_inv_std_load = + ThreadwiseTensorSliceTransfer_v2( + inv_std_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dgamma_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DGammaDstVectorSize, + 
InMemoryDataOperationEnum::Set, + 1, + true>( + dgamma_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dbeta_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DBetaDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dbeta_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + dgamma_thread_buf(I) = type_convert(0.0f); + dbeta_thread_buf(I) = type_convert(0.0f); + }); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, + thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + dgamma_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + (x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]); + + dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k]; + }); + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I)); + }); + + if(thread_k_cluster_id == 0) + { + threadwise_dgamma_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dgamma_thread_buf, + dgamma_grid_desc_m, + dgamma_global_val_buf); + + threadwise_dbeta_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dbeta_thread_buf, + dbeta_grid_desc_m, + dbeta_global_val_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp index c3f122106df29a471e376c453d055806c9aa2514..1bee1b93f9763efd6d5a92fddf65fbef8c62cb5c 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp @@ -18,9 +18,11 @@ template struct GridwiseNormalizationNaiveVariance_mk_to_mk { @@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + 
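For orientation, a minimal scalar sketch (not part of the patch) of the naive-variance path this struct implements, including the optional save of mean and inverse standard deviation added in this change; it assumes a single row and ignores the YElementwiseOperation epilogue.

// Scalar reference: mean = E[x], var = E[x^2] - E[x]^2, inv_std = 1 / sqrt(var + eps),
// y = (x - mean) * inv_std * gamma + beta; mean/inv_std are written out if requested.
#include <cmath>
#include <cstddef>
#include <vector>

void naive_layernorm_row(const std::vector<float>& x,
                         const std::vector<float>& gamma,
                         const std::vector<float>& beta,
                         float epsilon,
                         std::vector<float>& y,
                         float* save_mean,    // may be nullptr
                         float* save_inv_std) // may be nullptr
{
    float sum = 0.f, sum_sq = 0.f;
    for(float v : x)
    {
        sum += v;
        sum_sq += v * v;
    }
    const float k       = static_cast<float>(x.size());
    const float mean    = sum / k;
    const float var     = sum_sq / k - mean * mean;
    const float inv_std = 1.f / std::sqrt(var + epsilon);

    if(save_mean != nullptr)
        *save_mean = mean;
    if(save_inv_std != nullptr)
        *save_inv_std = inv_std;

    y.resize(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i];
}

Saving inv_std rather than the raw variance lets the backward kernels above consume it directly instead of recomputing 1 / sqrt(var + epsilon).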
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize); @@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = @@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk reduce::Add, true>; + using PassThroughOp = tensor_operation::element_wise::PassThrough; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, + const GridDesc_M& save_mean_grid_desc_m, + const GridDesc_M& save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { // LDS @@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk auto y_global_val_buf = make_dynamic_buffer( p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + auto x_thread_buf = generate_tuple( [&](auto) { return StaticBuffer& var_thread_buf = mean_square_thread_buf; + StaticBuffer& + inv_std_thread_buf = mean_square_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk thread_k_cluster_id * YDstVectorSize), y_elementwise_op); + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, SweepOnce ? 
0 : -K_BlockTileSize); @@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // E(x), E[x^2], var(x) // FIXME: Should not hack the transform from deviceOP - int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + ComputeDataType reduce_length = type_convert( + x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { mean_thread_buf(I) = reduce::Add::template GetIdentityValue(); @@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // var(x) = E[x^2] - E[x]^2 var_thread_buf(I) = mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + + inv_std_thread_buf(I) = type_convert(1.0f) / + ck::math::sqrt(var_thread_buf(I) + epsilon); }); + // save mean and inverse std for backward (optional) + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // normalization static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma & beta y_thread_buf(iK0)(Number{}) = @@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // var(x) = E[x^2] - E[x]^2 var_thread_buf(I) = mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + + inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon); }); + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; @@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp index e50fb9813325b08576f49d321e7a5de5d51bbff2..8157b4fbc39cc1ce9a3958ce2036e094b0d7bdcd 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp @@ -12,31 +12,42 @@ template -__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, - const GridDesc_M_K gamma_grid_desc_m_k, - const GridDesc_M_K beta_grid_desc_m_k, - const GridDesc_M_K y_grid_desc_m_k, - index_t num_k_block_tile_iteration, - ComputeDataType epsilon, - const XDataType* const __restrict__ p_x_global, - const GammaDataType* const __restrict__ p_gamma_global, - const BetaDataType* const __restrict__ p_beta_global, - YDataType* const __restrict__ p_y_global, - const YElementwiseOperation y_elementwise_op) + typename GridDesc_M_K, + typename GridDesc_M> +__global__ void +kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K beta_grid_desc_m_k, + const GridDesc_M_K y_grid_desc_m_k, + const GridDesc_M save_mean_grid_desc_m, + const GridDesc_M save_inv_std_grid_desc_m, + index_t num_k_block_tile_iteration, + ComputeDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) { GridwiseReduction::Run(x_grid_desc_m_k, gamma_grid_desc_m_k, beta_grid_desc_m_k, y_grid_desc_m_k, + save_mean_grid_desc_m, + save_inv_std_grid_desc_m, num_k_block_tile_iteration, epsilon, p_x_global, p_gamma_global, p_beta_global, p_y_global, + p_save_mean_global, + p_save_inv_std_global, y_elementwise_op); }; @@ -44,9 +55,11 @@ template auto NormalizationKernelSelector(bool isSweepOnce) { @@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce) GammaDataType, BetaDataType, YDataType, + SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, GridDesc_M_K, + GridDesc_M, BlockSize, MThreadClusterSize, KThreadClusterSize, @@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce) BetaSrcVectorSize, YDstVectorDim, YDstVectorSize, + SaveMeanInvStdDstVectorSize, false>; using GridwiseNormalizationSweepOnceNaive = GridwiseNormalizationNaiveVariance_mk_to_mk; using GridwiseNormalizationGenericWelford = GridwiseNormalizationWelfordVariance_mk_to_mk; using GridwiseNormalizationSweepOnceWelford = GridwiseNormalizationWelfordVariance_mk_to_mk; if constexpr(UseWelford) @@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) GammaDataType, BetaDataType, YDataType, + SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, - GridDesc_M_K> + GridDesc_M_K, + GridDesc_M> : kernel_normalization; + GridDesc_M_K, + GridDesc_M>; } else { @@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) GammaDataType, BetaDataType, YDataType, + SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, - GridDesc_M_K> + GridDesc_M_K, + GridDesc_M> : kernel_normalization; + GridDesc_M_K, + GridDesc_M>; } } diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp index 136ac94e7f0fabb81cfd5baea7e58174799f4d43..9e380f963801460e93684ba4a5cf3244c3fef1cb 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp @@ -17,11 +17,13 @@ template + index_t YDstVectorSize, + index_t SaveMeanInvStdDstVectorSize> struct GridwiseNormalizationSplitK2nd { static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || @@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize); @@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + using ThreadBufferLengths_M_1 = Sequence; static constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1)); @@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k, + const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m, + const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m, index_t num_k_mean_var_count_iteration, index_t num_k_block_tile_iteration, index_t k_grid_size, @@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { // Thread/Block id @@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd auto y_global_val_buf = make_dynamic_buffer( p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + // VGPR StaticBuffer in_mean_thread_buf; @@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd var_thread_buf; StaticBuffer welford_count_thread_buf; + auto& inv_std_thread_buf = var_thread_buf; auto x_thread_buf = generate_tuple( [&](auto) { @@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd thread_k_cluster_id * YDstVectorSize), y_elementwise_op); + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + 
make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + // step1: Merge mean and variance constexpr auto mean_var_count_thread_copy_step_I0_k = make_multi_index(I0, KThreadClusterSize); @@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd BlockwiseWelford::Run( mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I)); + + inv_std_thread_buf(I) = + type_convert(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon); }); - // step2: normalization + // step2: save mean and inverse std for backward (optional) + if(block_k_cluster_id == 0 && thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // step3: normalization constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); for(index_t k = 0; k < num_k_block_tile_iteration; ++k) @@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp index ff9712276c26d14c9a68f9df618ed4666f2dc5e9..15b412fba4cc0dd3d33e6b9a4f663c94a2053bda 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -16,9 +16,11 @@ template struct GridwiseNormalizationWelfordVariance_mk_to_mk { @@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize); @@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = @@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ThreadClusterLengths_M_K, ThreadClusterArrangeOrder>; + using PassThroughOp = 
tensor_operation::element_wise::PassThrough; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, + const GridDesc_M& save_mean_grid_desc_m, + const GridDesc_M& save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { - auto y_global_val_buf = make_dynamic_buffer( - p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); - auto x_thread_buf = generate_tuple( [&](auto) { return StaticBuffer var_thread_buf; + auto& inv_std_thread_buf = var_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_k_cluster_id * YDstVectorSize), y_elementwise_op); + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, SweepOnce ? 
0 : -K_BlockTileSize); @@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const auto beta_global_val_buf = make_dynamic_buffer( p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + auto threadwise_welford = ThreadwiseWelford(); threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); @@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk int count = threadwise_welford.cur_count_; BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = type_convert(1.0f) / + ck::math::sqrt(var_thread_buf(I) + epsilon); }); + // save mean and inverse std for backward (optional) + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // normalization static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma & beta y_thread_buf(iK0)(Number{}) = @@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk int count = threadwise_welford.cur_count_; BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon); }); + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; @@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp deleted file mode 
100644 index d0d214381d14e64f91805ca2d605d78f66470c3f..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp +++ /dev/null @@ -1,136 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/amd_gemm_dpp.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/utility/inner_product_dpp8.hpp" -#include "ck/utility/math.hpp" - -namespace ck { - -/** - * Threadwise contraction using dot instructions with DPP8 modifier. - * - * Assumptions: - * 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1` - * are known at compile-time; - * 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time; - * 3. `TM0` is equal to 1 and `TN0` is equal to 1; - * 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by - * the size of the lane group (`dpp8::lane_group_size`). - */ -template ::type = false> -struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 -{ - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t TK0 = TKLengths{}[I0]; - static constexpr index_t TK1 = TKLengths{}[I1]; - static constexpr index_t TM0 = TMLengths{}[I0]; - static constexpr index_t TM1 = TMLengths{}[I1]; - static constexpr index_t TN0 = TNLengths{}[I0]; - static constexpr index_t TN1 = TNLengths{}[I1]; - - static_assert(TM0 == 1 && TN0 == 1); - - static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) || - (!ShareA && TN1 % dpp8::lane_group_size == 0)); - static constexpr index_t shared_elems_per_lane = - ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size; - - __device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() - { - static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && - BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && - CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, - "wrong!"); - } - - template - __device__ static void Run(const ABuffer& a_buf, - AOriginIdx, - const BBuffer& b_buf, - BOriginIdx, - CBuffer& c_buf, - COriginIdx) - { - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); - constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); - constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - - static_for<0, TK0, 1>{}([&](auto tk0) { - static_for<0, TM1, 1>{}([&](auto tm1) { - static_for<0, TN1, 1>{}([&](auto tn1) { - vector_type a_vec; - vector_type b_vec; - - static_for<0, TK1, 1>{}([&](auto tk1) { - constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1; - constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( - a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1)); - - constexpr index_t local_tn1 = ShareA ? 
tn1 : tn1 % shared_elems_per_lane; - constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( - b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1)); - - a_vec.template AsType()(tk1) = a_buf[Number{}]; - b_vec.template AsType()(tk1) = b_buf[Number{}]; - }); - - using a_vector_t = typename vector_type::type; - using b_vector_t = typename vector_type::type; - - constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( - c_origin_idx + make_multi_index(0, tm1, 0, tn1)); - - constexpr int src_lane = - ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size - : (tn1 / shared_elems_per_lane) % dpp8::lane_group_size; - - dpp8::inner_product_dpp( - a_vec.template AsType()[I0], - b_vec.template AsType()[I0], - c_buf(Number{})); - }); - }); - }); - } -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 605f2569c6622bb589d46da71020ea46ddba69f7..d7a6a3624410dea9777c7cf4299ff9c525696194 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -8,38 +8,11 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" -namespace ck { +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory -// and sometimes useless instructions: -// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument -// instead -// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same -// tensor coordinate instead -// 3. Don't use a pointer to VGPR buffer, use vector instead - -namespace detail { -// TODO: How to fix this? It uses an struct instead of lambda because lambda -// doesn't have constructor -template -struct lambda_scalar_per_access -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - return (i == VectorDim) ? ScalarPerVector : 1; - } -}; - -template -struct lambda_scalar_step_in_vector -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - return (i == VectorDim) ? 1 : 0; - } -}; -} // namespace detail +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" +namespace ck { // Assume: // 1. src: // 1. 
SrcDesc is known at compile-time @@ -137,13 +110,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr index_t src_offset = src_desc.CalculateOffset( src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); - // apply type convert - dst_vector.template AsType()(i) = type_convert(v); + dst_vector.template AsType()(i) = v; }); const bool is_dst_valid = @@ -1157,27 +1129,56 @@ struct ThreadwiseTensorSliceTransfer_v4 src_ref_to_origin_disp_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - // apply type convert src_tmp_vector.template AsType()(i) = src_buf[Number{}]; }); } - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - // TODO: if SrcData and DstData are vetor type, then static_cast may not compile - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - dst_tmp_vector.template AsType()(i) = - type_convert(src_tmp_vector.template AsType()[i]); - }); + if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 2; + + using dst_v_t = typename vector_type_maker_t::type; + using src_v_t = typename vector_type_maker_t::type; + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack2{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } }); } @@ -1289,13 +1290,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic constexpr index_t dst_offset = dst_desc.CalculateOffset( dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); // apply type convert - dst_buf(Number{}) = type_convert(v); + dst_buf(Number{}) = v; }); }); } @@ -1303,4 +1304,246 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ElementwiseOperation element_op_; }; +// Specilized for WMMA-Navi3 +// A single Wave32 is composed by double 
row +// Data exchange allowed between these two rows +// This RowLane Dst buf will be filled from two Src buf +// SrcA: From specific thread buffer hold by This RowLane on This Row +// SrcB: From specific thread buffer hold by This RowLane on The other Row +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row, v_theother_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. + if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } + + // apply inter-row permute. 
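+ // __builtin_amdgcn_permlanex16 exchanges values across the two 16-lane rows of a wave32:
+ // each lane fetches from a lane of the opposite row, with per-lane 4-bit source selects
+ // packed into the two dword operands (LowEightRowlaneIdx covers lanes 0-7,
+ // HighEightRowLaneIdx covers lanes 8-15); the trailing 1, 0 arguments are the
+ // fi (fetch-inactive) and bound_ctrl flags.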
+ temp = __builtin_amdgcn_permlanex16(temp, + type_convert_sp(v_this_row), + LowEightRowlaneIdx, + HighEightRowLaneIdx, + 1, + 0); + v_theother_row = type_convert_sp(temp); + + if(get_thread_local_1d_id() % 32 < 16) + { + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + dst_buf(Number{}) = + type_convert_sp(v_theother_row); + } + else + { + // apply type convert + dst_buf(Number{}) = + type_convert_sp(v_this_row); + dst_buf(Number{}) = type_convert_sp(v_theother_row); + } + }); + }); + } + ElementwiseOperation element_op_{}; +}; + +// Specilized for WMMA-Navi4 +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. 
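+ // __builtin_amdgcn_permlane16 shuffles values within the same 16-lane row; the constants
+ // 0xb3a29180 and 0xf7e6d5c4 pack per-lane 4-bit source selects (for lanes 0-7 and 8-15
+ // respectively), which here appear to interleave the low and high halves of the row.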
+ if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } + + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + }); + }); + } + ElementwiseOperation element_op_{}; +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..96b95579f524613c72b49e6a3377913e317b214b --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? ScalarPerVector : 1; + } +}; + +template +struct lambda_scalar_step_in_vector +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? 1 : 0; + } +}; + +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 32ea8ae39c60c0e2130ab9381d94ecd8ea399982..96ea04c8fa12bffa868b5dd6f549d6e43fc9fbf7 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,42 +7,12 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" +#include "ck/utility/is_detected.hpp" -namespace ck { - -namespace detail { -// TODO: How to fix this? 
It uses an struct instead of lambda because lambda -// doesn't have constructor -template -struct lambda_scalar_per_access_for_src_and_dst -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - if(i == SrcVectorDim && i == DstVectorDim) - { - return math::lcm(SrcScalarPerVector, DstScalarPerVector); - } - else if(i == SrcVectorDim) - { - return SrcScalarPerVector; - } - else if(i == DstVectorDim) - { - return DstScalarPerVector; - } - else - { - return 1; - } - } -}; +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" -} // namespace detail +namespace ck { // Assume: // 1. src_desc and dst_desc are not known at compile-time @@ -201,20 +171,56 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); + // maintain a container record is_src_valid, waiting for RunWrite use. const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + src_oob_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, is_src_valid); using src_vector_type = vector_type_maker_t; using src_vector_t = typename src_vector_type::type; - // copy data from src_buf into src_vector_container - auto src_vector_container = src_vector_type{ - src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + auto src_vector_container = + src_vector_type{src_buf.template Get(src_coord_.GetOffset(), true)}; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + dst_vector_type op_r_v; + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + using src_elem_op_vec_t = typename vector_type::type; + using dst_elem_op_vec_t = typename vector_type::type; + + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if needed + src_element_op_(op_r_v.template AsType()(idx), + src_vector_container.template AsType()[idx]); + }); // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType( - src_data_idx_seq, src_vector_container.template AsType()[I0]); + .template SetAsType(src_data_idx_seq, + op_r_v.template AsType()[I0]); constexpr auto move_on_dim = [&]() constexpr { @@ -267,19 +273,81 @@ struct ThreadwiseTensorSliceTransfer_v3r1 { #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = - type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); #else + + // OOB Check + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + 
constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using vector_t = typename vector_type_maker::type::type; + + auto op_r = src_thread_scratch_tuple_(thread_scratch_id) + .template GetAsType(src_data_idx_seq); + + const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id) + .template GetAsType(src_data_idx_seq); + + auto op_r_v = is_src_valid ? op_r : vector_t(0); + + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, op_r_v); + }); + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ // TODO make this logic more generic for more sub-dword datatype if constexpr(SrcVectorDim != DstVectorDim && - ((is_same>::value && - is_same>::value && + ((is_same>::value && SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || - (is_same>::value && - is_same>::value && + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) { // each transpose does @@ -313,7 +381,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); - using src_vector_t = vector_type_maker_t; + using src_vector_t = vector_type_maker_t; using dst_vector_t = vector_type_maker_t; // get DstScalarPerVector # of read-only references to src vectors from @@ -336,17 +404,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Number{}); // do data transpose - transpose_vectors{}( + transpose_vectors{}( src_vector_refs, dst_vector_refs); }); } - - static_ford{}([&](auto idx) { - // apply the src elementwise op and convert to DstData under the hood if needed - DstData dst_v; - src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]); - dst_thread_scratch_(idx) = dst_v; - }); + else + { + static_ford{}([&](auto idx) { + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; + }); + } #endif } @@ -356,6 +423,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Number thread_scratch_id = Number{}) { // if there is transpose, it's done here + // if there is oob check, it's done here // TODO move this elsewhere TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); @@ -708,6 +776,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } + __device__ static constexpr auto 
GetSrcOOBThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + return make_naive_tensor_descriptor_packed(src_access_lengths); + } + __device__ static constexpr auto GetDstThreadScratchDescriptor() { // 1st stage of transforms @@ -759,13 +837,23 @@ struct ThreadwiseTensorSliceTransfer_v3r1 private: static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto src_oob_thread_scratch_desc_ = + decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; - using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer; + + using SrcOOBThreadScratch = + StaticTensorTupleOfVectorBuffer; using DstThreadScratch = StaticTensorTupleOfVectorBuffer; StaticallyIndexedArray src_thread_scratch_tuple_; + StaticallyIndexedArray src_oob_thread_scratch_tuple_; DstThreadScratch dst_thread_scratch_; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp new file mode 100644 index 0000000000000000000000000000000000000000..174b82f8700f17052827a63d3d36cee73736803f --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp @@ -0,0 +1,1066 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst_idle +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +// 5. Dequantization happened between read and write. 
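+// Dequantization itself happens in TransferDataFromSrcThreadScratchToDstThreadScratch:
+// the staged SrcData is fast-converted (FastNumericArrayConverter) and multiplied by the
+// staged scale before the src elementwise op produces DstData (see the
+// CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE path below).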
+template +struct ThreadwiseTensorSliceTransfer_v3r1_dequant +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const ScaleDesc& scale_desc, + const Index& scale_slice_origin, + const ScaleElementwiseOperation& scale_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + scale_coord_(make_tensor_coordinate(scale_desc, scale_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + scale_element_op_(scale_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc, + const Index& scale_slice_origin_idx) + { + scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? 
-src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf) + { + static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
ScaleBuffer and ScaleData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scale_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; + + constexpr auto scale_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_scale_access_lengths = + container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order); + + // make forward steps + const auto scale_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(scale_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto scale_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(scale_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_scale_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_scale_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate scale data index + constexpr auto scale_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_scale_access_idx[i] + : ordered_scale_access_lengths[i] - 1 - + ordered_scale_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) * + scale_scalar_per_access; + }(); + + constexpr auto scale_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_scale_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + scale_desc, scale_coord_); + + using scale_vector_type = vector_type_maker_t; + using scale_vector_t = typename scale_vector_type::type; + + // copy data from scale_buf into scale_vector_container + auto scale_vector_container = scale_vector_type{ + scale_buf.template Get(scale_coord_.GetOffset(), is_scale_valid)}; + + // copy data from scale_vector_container into scale_thread_scratch_ + scale_thread_scratch_.template SetAsType( + scale_data_idx_seq, scale_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = + ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move scale coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(scale_desc, + scale_coord_, + scale_forward_steps[scale_dim_access_order[i]]); + } + else + { + move_tensor_coordinate(scale_desc, + scale_coord_, + scale_backward_steps[scale_dim_access_order[i]]); + } + } + }); + }); + + // don't need to move scale coordinate back to slice origin + /* + if constexpr(SrcResetCoordinateAfterRun) + { + const auto scale_reset_step = + make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep()); + + move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step); + } + */ + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + 
detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + + // Do fast numeric convert + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using src_converted_vector_type = vector_type_maker_t; + using src_converted_vector_t = typename src_converted_vector_type::type; + // Vector-wise type convert + static_ford{}([&](auto access_idx) { + auto src_vector_container = src_vector_type{ + src_thread_scratch_tuple_[thread_scratch_id].template GetAsType( + access_idx)}; + + auto src_converted_vector_container = + src_converted_vector_type{fast_numeric_converter(src_vector_container)}; + + src_converted_thread_scratch_.template SetAsType( + access_idx, + src_converted_vector_container.template AsType()[I0]); + }); + + // Element-scale operation, expect packed multiplication + static_ford{}([&](auto idx) { + DstData dst_v; + constexpr auto scale_idx = Sequence{}; + // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(), + // *(reinterpret_cast(&scale_thread_scratch_[scale_idx]))); + src_element_op_(dst_v, + src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]); + dst_thread_scratch_(idx) = dst_v; + }); +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer or DstBuffer data type is wrong"); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + + dst_vector_container.template AsType()(i) = dst_v; + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetScaleThreadScratchDescriptor() + { + + constexpr auto scale_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; + + constexpr auto scale_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(scale_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(scale_access_lengths_and_vector_length[i], + scale_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(scale_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + 
make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto scale_thread_scratch_desc_ = + decltype(GetScaleThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + /* + template + struct ScaleThreadScratchDesc{}; + */ + + // Registers, contain raw data loaded from global buffer + using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + + // Registers, contain fast converted data + using SrcThreadConvertedScratch = + StaticTensorTupleOfVectorBuffer; + + // Registers, contain scale data + using ScaleThreadScratch = StaticTensorTupleOfVectorBuffer; + + // Registers, contain dequantized data + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + using FastTypeConverter = tensor_operation::element_wise:: + FastNumericArrayConverter; + + StaticallyIndexedArray src_thread_scratch_tuple_; + SrcThreadConvertedScratch src_converted_thread_scratch_; + ScaleThreadScratch scale_thread_scratch_; + + DstThreadScratch dst_thread_scratch_; + FastTypeConverter fast_numeric_converter; + + SrcCoord src_coord_; + ScaleCoord scale_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const ScaleElementwiseOperation scale_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f0d793456d335409da53aa9bec9191b6a8adfa2b --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -0,0 +1,804 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. 
Use thread buffer +template +struct ThreadwiseTensorSliceTransfer_v3r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto src_i) { + src_coords_(src_i) = + make_tensor_coordinate(src_descs.At(src_i), src_slice_origin_idxs[src_i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto dst_i) { + dst_coords_(dst_i) = + make_tensor_coordinate(dst_descs.At(dst_i), dst_slice_origin_idxs[dst_i]); + }); + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access_tuple = generate_tuple( + [&](auto src_i) { + return generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + }, + Number{}); + + constexpr auto src_access_lengths_tuple = generate_tuple( + [&](auto src_i) { + return SliceLengths{} / src_scalar_per_access_tuple.At(src_i); + static_assert( + SliceLengths::At(SrcVectorDim) % SrcsScalarPerVector::At(src_i) == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcsScalarPerVector"); + }, + Number{}); + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths_tuple = generate_tuple( + [&](auto src_i) { + return container_reorder_given_new2old(src_access_lengths_tuple.At(src_i), + src_dim_access_order); + }, + Number{}); + + // make forward steps + const auto src_forward_steps_tuple = generate_tuple( + [&](auto src_i) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0; + }); + + return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx); + }, + Number{}); + }, + Number{}); + + // make backward steps + const auto src_backward_steps_tuple = generate_tuple( + [&](auto src_i) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? 
-src_scalar_per_access_tuple.At(src_i)[i] + : 0; + }); + + return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx); + }, + Number{}); + }, + Number{}); + + // loop over tensor and copy + static_for<0, nSrc, 1>{}([&](auto src_i) { + static_ford>{}( + [&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths_tuple[j] + + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_src_access_idx[i] + : ordered_src_access_lengths_tuple.At(src_i)[i] - + 1 - ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access_tuple.At(src_i); + }(); + + constexpr auto src_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_descs.At(src_i), src_coords_.At(src_i)); + + using src_vector_type = vector_type_maker_t, + SrcsScalarPerVector::At(src_i)>; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = + src_vector_type{src_bufs.At(src_i).template Get( + src_coords_.At(src_i).GetOffset(), is_src_valid)}; + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .At(src_i) + .template SetAsType( + src_data_idx_seq, + src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < + ordered_src_access_lengths_tuple.At(src_i)[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == + ordered_src_access_lengths_tuple.At(src_i)[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_descs.At(src_i), + src_coords_.At(src_i), + src_forward_steps_tuple.At(src_i)[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_descs.At(src_i), + src_coords_.At(src_i), + src_backward_steps_tuple.At(src_i)[src_dim_access_order[i]]); + } + } + }); + }); + }); + + static_for<0, nSrc, 1>{}([&](auto src_i) { + // move src coordinate back to slice origin (or not) + if constexpr(SrcsResetCoordinateAfterRun::At(src_i)) + { + const auto src_reset_step = make_tensor_coordinate_step( + src_descs.At(src_i), GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), src_reset_step); + } + }); + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { + // TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + // (it requires to add Elementwise support in transpose_vectors) + static_ford{}([&](auto idx) { + const auto src_data_refs = generate_tie( + 
[&](auto src_i) -> const auto& { + return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx]; + }, + Number{}); + + auto dst_data_refs = generate_tie( + [&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); }, + Number{}); + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + } + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers& dst_bufs, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access_tuple = generate_tuple( + [&](auto dst_i) { + return generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + }, + Number{}); + + constexpr auto dst_access_lengths_tuple = generate_tuple( + [&](auto dst_i) { return SliceLengths{} / dst_scalar_per_access_tuple.At(dst_i); }, + Number{}); + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths_tuple = generate_tuple( + [&](auto dst_i) { + return container_reorder_given_new2old(dst_access_lengths_tuple.At(dst_i), + dst_dim_access_order); + }, + Number{}); + + // make forward steps + const auto dst_forward_steps_tuple = generate_tuple( + [&](auto dst_i) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0; + }); + + return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx); + }, + Number{}); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps_tuple = generate_tuple( + [&](auto dst_i) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? -dst_scalar_per_access_tuple.At(dst_i)[i] + : 0; + }); + + return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx); + }, + Number{}); + }, + Number{}); + + // loop over tensor and copy + static_for<0, nDst, 1>{}([&](auto dst_i) { + static_ford>{}( + [&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] + + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths_tuple.At(dst_i)[i] - + 1 - ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access_tuple.At(dst_i); + }(); + + constexpr auto dst_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid( + dst_descs.At(dst_i), dst_coords_.At(dst_i)); + + using dst_vector_type = vector_type_maker_t, + DstsScalarPerVector::At(dst_i)>; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_tuple_.At(dst_i).template GetAsType( + dst_data_idx_seq)}; + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(dst_i.value)); + + // copy data from dst_vector_container to dst_buf + dst_bufs.At(dst_i).template Update( + dst_coords_.At(dst_i).GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < + ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == + ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_descs.At(dst_i), + dst_coords_.At(dst_i), + dst_forward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_descs.At(dst_i), + dst_coords_.At(dst_i), + dst_backward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]); + } + } + }); + }); + }); + + // move dst coordinate back to slice origin (or not) + static_for<0, nDst, 1>{}([&](auto dst_i) { + if constexpr(DstsResetCoordinateAfterRun::At(dst_i)) + { + const auto dst_reset_step = make_tensor_coordinate_step( + dst_descs.At(dst_i), GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), dst_reset_step); + } + }); + } + + template + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index 
ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + template + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access.At(dst_i); + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + const Index& src_slice_origin_step_idx) + { + static_for<0, nSrc, 1>{}([&](auto src_i) { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcsResetCoordinateAfterRun::At(src_i) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_tensor_coordinate_step(src_descs.At(src_i), adjusted_step_idx); + + move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), adjusted_step); + }); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + const Index& dst_slice_origin_step_idx) + { + static_for<0, nDst, 1>{}([&](auto dst_i) { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstsResetCoordinateAfterRun::At(dst_i) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
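When the destination coordinate is not reset after RunWrite(), it is still parked at the data index of the last access, so the window move below folds the reset step (the negative of that index) into the caller's step and issues a single coordinate update. A minimal 1-D sketch of that arithmetic, with assumed numbers (slice length 8, 4 destination scalars per vector):

    // Assumed: SliceLengths = {8}, destination scalars-per-vector = 4.
    constexpr int slice_len      = 8;
    constexpr int scalar_per_vec = 4;
    constexpr int num_accesses   = slice_len / scalar_per_vec;          // 2 accesses per RunWrite()
    constexpr int last_data_idx  = (num_accesses - 1) * scalar_per_vec; // 4: where RunWrite() leaves the coord
    constexpr int reset_step     = -last_data_idx;                      // what GetDstCoordinateResetStep() yields here
    constexpr int window_step    = 8;                                   // step requested by the caller
    constexpr int adjusted_step  = window_step + reset_step;            // 4
    static_assert(last_data_idx + adjusted_step == window_step,
                  "one combined move lands on the next slice origin");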
+ const auto adjusted_step = + make_tensor_coordinate_step(dst_descs.At(dst_i), adjusted_step_idx); + + move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), adjusted_step); + }); + } + + template + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = + container_push_back(sequence_to_tuple_of_number(src_access_lengths), + Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + template + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = + container_push_back(sequence_to_tuple_of_number(dst_access_lengths), + Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto MakeSrcThreadScratchTuple() + { + return generate_tuple( + [&](auto src_i) { + constexpr auto src_thread_scratch_desc = + decltype(GetSrcThreadScratchDescriptor()){}; + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer, + SrcsScalarPerVector::At(src_i), + decltype(src_thread_scratch_desc), + true>; + return SrcThreadScratch{}; + }, + Number{}); + } + + __device__ static constexpr auto MakeDstThreadScratchTuple() + { + return generate_tuple( + [&](auto dst_i) { + constexpr auto dst_thread_scratch_desc = + decltype(GetDstThreadScratchDescriptor()){}; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer, + DstsScalarPerVector::At(dst_i), + 
decltype(dst_thread_scratch_desc), + true>; + return DstThreadScratch{}; + }, + Number{}); + } + + private: + using SrcThreadScratchTuple = decltype(MakeSrcThreadScratchTuple()); + using DstThreadScratchTuple = decltype(MakeDstThreadScratchTuple()); + + StaticallyIndexedArray src_thread_scratch_tuple_; + + DstThreadScratchTuple dst_thread_scratch_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index bd01108b03ca82d4877e19806521efda8da8c6c3..40ebdeff08e8627fc4bb3e237812b483f13bc8d2 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -315,7 +315,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1 forward_sweep_(I0) = true; static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; + index_t tmp = 0; static_for<0, i, 1>{}([&](auto j) { tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index 6ec9abc41708735011acda7fcd01e54cbec5c4a6..644877d3931f08ea46b017446140484de8746fa9 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1 // apply pointwise operation static_for<0, ScalarPerVector, 1>{}([&](auto i) { - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_vector_container.template AsType()[i]); // apply type convert - dst_vector_container.template AsType()(i) = type_convert(v); + dst_vector_container.template AsType()(i) = v; }); const bool is_dst_valid = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4b277e4383f20cb0f6209b5af0e4c761d0eda6b2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,630 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. 
Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags, // Sequence + index_t NumThreadScratch = 1> +struct ThreadwiseTensorSliceTransfer_v7r2 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + + using DstSpaceFillingCurve = SpaceFillingCurve, + false>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) 
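As the comment above spells out, unpack2 expands the tuple of destination references followed by the tuple of source references into a single call on element_op_. A hypothetical functor shape that would be compatible with one destination and two sources (the name and body are illustrative, not taken from the library):

    // Illustrative only: element_op_(dst0, src0, src1) for nDst = 1, nSrc = 2.
    struct HypotheticalAdd
    {
        template <typename Dst, typename Src0, typename Src1>
        void operator()(Dst& dst, const Src0& s0, const Src1& s1) const
        {
            // destinations are written through the lvalue refs, sources are read-only
            dst = static_cast<Dst>(s0 + s1);
        }
    };

In kernel code the call operator would additionally carry the __device__ qualifier used throughout this header.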
+ unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + +#if 1 + template + __device__ void OOBCheck(Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess]; + auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess]; + + static_for<0, nDst, 1>{}([&](auto i) { + using elm_vector_t = typename remove_cvref_t::type; + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + }); + } +#endif + + template + __device__ void + TransposeFromElmToDst(Number thread_scratch_id = Number{}) + { + using DstData = remove_cvref_t; + + using ElmThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + ElmThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_[thread_scratch_id]); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in 
DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + OOBCheck(thread_scratch_id); + TransposeFromElmToDst(thread_scratch_id); + + // loop over space-filling curve + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(src_num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(dst_num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + 
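With assumed numbers the two stages of this descriptor are easy to trace: for SliceLengths = {4, 8}, SrcVectorDim = 1 and SrcScalarPerVector = 4, the access lengths are {4, 2}, the packed first-stage descriptor has lengths {4, 2, 4} (strides {8, 4, 1}), and the second-stage merge folds the trailing {2, 4} pair back into one dimension of length 8. The scratch is therefore addressed with the original slice shape while its storage stays grouped into 4-wide vectors along the vector dimension:

    // Offset of upper index (m, d) through the merged view, for the assumed numbers above.
    constexpr int packed_offset(int m, int d) { return m * 8 + (d / 4) * 4 + d % 4; }
    static_assert(packed_offset(2, 5) == 2 * 8 + 5, "the merge is offset-preserving here");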
constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
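Note that, unlike the _v3r2 transfer above whose MoveSrcSliceWindow/MoveDstSliceWindow step every tensor in one call, these functions take the tensor index as an extra Number<> argument, so each source or destination window is advanced on its own. A hypothetical call site (the object and step names are assumptions, not code from this change):

    // copy                  : a ThreadwiseTensorSliceTransfer_v7r2 instance
    // src_descs / dst_descs : the descriptor tuples passed to RunRead()/RunWrite()
    // step                  : the compile-time window step for the current loop iteration
    copy.MoveSrcSliceWindow(src_descs, Number<0>{}, step); // advance only source 0
    copy.MoveSrcSliceWindow(src_descs, Number<1>{}, step); // advance only source 1
    copy.MoveDstSliceWindow(dst_descs, Number<0>{}, step); // advance only destination 0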
+ const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); + + using ElmVectorTuple = StaticallyIndexedArray; + using DstVectorTuple = StaticallyIndexedArray; + + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; + + using OOBVectorTuple = StaticallyIndexedArray; + StaticallyIndexedArray oob_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea074144b66589bc861f6c57587bf9b658110046 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -0,0 +1,648 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags, // Sequence + index_t NumThreadScratch = 1> +struct ThreadwiseTensorSliceTransfer_v7r3 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr auto SrcScalarPerVector = SrcScalarPerVectors{}[I0]; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + + using DstSpaceFillingCurve = SpaceFillingCurve, + false>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r3( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + if constexpr(SrcScalarPerVectors{}[i] == 1) + { + auto data_types = SrcDatas{}; + using DataType = remove_cvref_t; + const auto tmp = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + + static_for<0, SrcScalarPerVector, 1>{}( + [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); + } + else + { + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + } + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) 
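The get_elem_op_vec_len cascade a few lines above asks the elementwise operation which packed forms it can be invoked with (8-, 4- or 2-wide) and falls back to scalar otherwise; the is_detected checks guard that access, so functors that do not declare a given is_packN_invocable member still compile. A minimal standalone sketch of the width selection, with a hypothetical functor that only advertises 2-wide packing:

    // Hypothetical functor: only the pack2 form is advertised.
    struct HypotheticalScaleOp
    {
        static constexpr bool is_pack2_invocable = true;
        // ... scalar and 2-wide packed operator() overloads would live here ...
    };

    template <typename Op, int SrcScalarPerVector>
    constexpr int choose_elem_op_vec_len()
    {
        // The real code first verifies with is_detected that the flag exists at all;
        // this sketch assumes it does and only mirrors the width selection.
        if constexpr(Op::is_pack2_invocable)
            return SrcScalarPerVector < 2 ? 1 : 2;
        return 1;
    }

    static_assert(choose_elem_op_vec_len<HypotheticalScaleOp, 8>() == 2,
                  "a vector of 8 scalars is processed as four 2-wide packs");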
+ unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + +#if 1 + template + __device__ void OOBCheck(Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess]; + auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess]; + + static_for<0, nDst, 1>{}([&](auto i) { + using elm_vector_t = typename remove_cvref_t::type; + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + }); + } +#endif + + template + __device__ void + TransposeFromElmToDst(Number thread_scratch_id = Number{}) + { + using DstData = remove_cvref_t; + + using ElmThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + ElmThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_[thread_scratch_id]); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in 
DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from + // dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + OOBCheck(thread_scratch_id); + TransposeFromElmToDst(thread_scratch_id); + + // loop over space-filling curve + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(src_num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(dst_num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of 
transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); + + using ElmVectorTuple = StaticallyIndexedArray; + using DstVectorTuple = StaticallyIndexedArray; + + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; + + using OOBVectorTuple = StaticallyIndexedArray; + StaticallyIndexedArray oob_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..409bb9f67465f6c45a607ceb603e0f033aa2840b --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -0,0 +1,538 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +enum struct DppInstr +{ + dpp8_f16_1x32x2 = 0, + dpp8_f16_2x16x2, + dpp8_f16_2x32x2, + dpp8_f16_4x16x2, + dpp8_f16_4x32x2, + dpp8_f16_8x16x2, + dpp8_f16_8x32x2, + dpp8_f16_16x16x2, + dpp8_f16_32x8x2 +}; + +/** + * Structure representing DPP GEMM executed by a single wavefront. + * + * Each structure instantiation must contain the following fields: + * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the + * number of threads in a wavefront; + * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier, + * it's 8 in case of DPP8; + * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup; + * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup; + * - m_per_thread - size along M dimension of the tile calculated by a single thread; + * - n_per_thread - size along N dimension of the tile calculated by a single thread; + * - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation; + * - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers. + * + * Not all the combinarions are supported now, for current restrictions see the static asserts + * in the DppSelector's contructor. 
+ */ +template +struct dpp_type; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 32; + static constexpr index_t n_per_wave = 8; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 16; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr 
index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 1; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template +struct DppSelector +{ + template + static constexpr auto GetDpp(); + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x32x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x16x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_16x16x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_32x8x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_1x32x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x32x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x16x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x16x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x32x2; + } + + static constexpr auto selected_dpp = dpp_type()>{}; + + __host__ __device__ constexpr DppSelector() + { + static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0); + 
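        // Annotation (added, not in the original patch): taking dpp8_f16_16x16x2 as a
        // concrete case, the divisibility checks in this constructor read
        // 16 % 8 == 0 (M), 16 % 8 == 0 (N), 2 % 2 == 0 (K) and 32 % 8 == 0 (wave),
        // and the wave splits into 32 / 8 = 4 lanegroups, which matches
        // (16 * 16) / (8 * 8) = 4 lanegroup-sized C sub-tiles.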
static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0); + + static_assert(selected_dpp.k_per_dpp % 2 == 0); + + static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0); + constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size; + constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave; + constexpr index_t num_dpp_c_elems = + selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup; + static_assert(num_wave_c_elems % num_dpp_c_elems == 0); + static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems); + + if constexpr(selected_dpp.share_a) + { + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + } + else + { + static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread); + } + + // Below checks come from the restrictions of the current implementation, could be removed + // in the future when the implementation is more generalized. + static_assert(selected_dpp.share_a); + static_assert(selected_dpp.n_per_thread == 1); + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup == + selected_dpp.n_per_thread * selected_dpp.lanegroup_size); + } + + static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; } +}; + +template +struct DppGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __host__ __device__ constexpr DppGemm() + { + static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); + } + + __device__ static constexpr index_t GetRegSizePerDpp() + { + return MPerDpp * NPerDpp / dpp_instr.wave_size; + } + + template + __device__ void + Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "base BaseType must be double, float, half, bfloat16, and int8_t!"); + + static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) { + dpp_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + }); + } + + __device__ static auto GetLaneIdInWave() + { + return get_thread_local_1d_id() % dpp_instr.wave_size; + } + + __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; } + + __device__ static auto GetLaneIdInLaneGroup() + { + return get_thread_local_1d_id() % dpp_instr.lanegroup_size; + } + + __device__ static auto GetLaneGroupIdInWave() + { + return GetLaneIdInWave() / dpp_instr.lanegroup_size; + } + + __device__ static auto GetDppOpIdx() + { + const auto lanegroupId = GetLaneGroupIdInWave(); + + constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup, + dpp_instr.n_per_wave / 
dpp_instr.n_per_lanegroup))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex( + make_multi_index(lanegroupId)); + + const auto m_dpp_idx = dpp_idx[I0]; + const auto n_dpp_idx = dpp_idx[I1]; + + return make_tuple(m_dpp_idx, n_dpp_idx); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M() + { + const auto laneId = get_thread_local_1d_id(); + const auto wave_row = laneId / dpp_instr.n_per_wave; + auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup(); + return make_tuple(0, m_idx % dpp_instr.m_per_wave); + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N() + { + const auto laneId = get_thread_local_1d_id(); + return make_tuple(0, laneId % dpp_instr.n_per_wave); + } + + __device__ static CIndex GetBeginOfThreadBlk() + { + const auto dpp_op_idx = GetDppOpIdx(); + + const auto m_dpp_op_idx = dpp_op_idx[I0]; + const auto n_dpp_op_idx = dpp_op_idx[I1]; + + index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup(); + index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup; + + return CIndex{m_offset, n_offset}; + } + + static constexpr auto dpp = DppSelector{}; + + static constexpr auto dpp_instr = dpp.selected_dpp; + + static constexpr auto K0PerDpp = 1; + static constexpr auto K1PerDpp = dpp.GetK1PerDpp(); + + __host__ __device__ static constexpr auto GetCMNThreadBlkLengths() + { + return make_tuple(Number{}, Number{}); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a436afd3951aa413136f8bdf49b6eb939610d405 --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
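// Annotation (not part of the original patch): this header mirrors xdlops_gemm.hpp
// for the "smfmac" sparse-MFMA instructions, which assume 2:4 structured sparsity
// in the A operand: every group of four K-elements of A stores only two values,
// and the extra `idx` argument threaded through run()/Run() below encodes, per
// group, which two of the four positions those values occupy. Apart from that
// index operand, the selector and instruction-descriptor structure follows the
// dense MFMA path.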
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_smfmac.hpp" + +namespace ck { + +enum struct SmfmacInstr +{ + smfmac_f32_16x16x32f16 = 0, + smfmac_f32_32x32x16f16, + smfmac_f32_16x16x32bf16, + smfmac_f32_32x32x16bf16, +}; + +template +struct smfmac_type; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32f16::Run( + a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16f16::Run( + a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32bf16::Run( + a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16bf16::Run( + a, b, idx, reg_c); + } +}; + +template +struct SmfmacSelector +{ + template + static constexpr auto GetSmfmac(); + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32f16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16f16; + } + + template <> + static constexpr 
auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32bf16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16bf16; + } + + static constexpr auto selected_smfmac = + smfmac_type()>{}; + + __host__ __device__ constexpr SmfmacSelector() + { + static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk == + selected_smfmac.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks == + selected_smfmac.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks || + selected_smfmac.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size == + selected_smfmac.m_per_blk * selected_smfmac.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_smfmac.is_k_reduction || + (selected_smfmac.num_input_blks == selected_smfmac.num_output_blks), + "is_k_reduction wrong!"); + } + + static constexpr index_t GetKPerXdlops() + { + return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) * + selected_smfmac.k_per_blk; + } + + static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; } +}; + +template +struct SparseXdlopsGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks); + } + + __host__ __device__ constexpr SparseXdlopsGemm() + { + static_assert(NPerXdlops == 16 || NPerXdlops == 32, + "Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(MPerXdlops == 16 || MPerXdlops == 32, + "Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + } + + // XDL output supporting C = A * B + // M2_N2 -> M2_M3_M4_N2 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); + } + + template + 
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(smfmac_instr.num_groups_per_blk, + smfmac_instr.num_input_blks, + smfmac_instr.group_size)), + make_pass_through_transform(smfmac_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / smfmac_instr.wave_size; + } + + __device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; } + + template + __device__ void + Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const + { + static_assert(is_same::value || is_same::value, + "base base_type must be half or bfloat16!"); + + static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) { + smfmac_instr.template run( + p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread); + }); + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size; + + return CIndex{m_offset, n_offset}; + } 
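    // Worked example (annotation, not in the original patch), using the 32x32x16
    // instances above (num_input_blks = 2, num_threads_per_blk = 32, group_size = 4,
    // m_per_blk = n_per_blk = 32): lane 45 decomposes in GetBlkIdx() into
    // blk_id = 45 / 32 = 1 and blk_td = 45 % 32 = 13, so for xdlops_i = 0 and
    // blk_i = 0 the thread's block begins at
    // (m, n) = (0 * 32 + 1 * 4, 0 * 32 + 13) = (4, 13).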
+ + __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + return CIndex4D{I0, blk_id, I0, blk_td}; + } + + static constexpr auto smfmac = + SmfmacSelector{}; + + static constexpr auto smfmac_instr = smfmac.selected_smfmac; + + static constexpr auto KPerXdlops = smfmac.GetKPerXdlops(); + static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() + { + return make_tuple( + Number{}, I1, Number{}, I1); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 979f3567e987ed66fd45af5deeae7bb2ac3fc761..b435a2a1293c87cd70bee4130e5a15f60bec6bda 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -11,12 +11,17 @@ namespace ck { enum struct WmmaInstr { + // gfx11 wmma_f32_16x16x16_f16 = 0, wmma_f32_16x16x16_bf16, wmma_f16_16x16x16_f16, wmma_bf16_16x16x16_bf16, wmma_i32_16x16x16_iu8, - wmma_i32_16x16x16_iu4 + wmma_i32_16x16x16_iu4, + // gfx12 + wmma_f32_16x16x16_f16_gfx12, + wmma_f32_16x16x16_bf16_gfx12, + wmma_i32_16x16x16_iu8_gfx12, }; /* @@ -89,18 +94,19 @@ struct wmma_type{}; - // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + // * Fixed on gfx11, Will be wave mode dependent for future architectures static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; // * num_acc_vgprs_per_wave alone M direction // * num_subgroups alone M direction static constexpr index_t num_acc_vgprs_per_wave = - m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4; static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; template @@ -129,6 +135,7 @@ struct wmma_type @@ -153,7 +160,6 @@ struct wmma_type struct wmma_type + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { if constexpr(wave_size == 32) { - intrin_wmma_f16_16x16x16_f16_w32::Run(a, b, reg_c); + intrin_wmma_f16_16x16x16_f16_w32::Run(a, b, reg_c); } else if constexpr(wave_size == 64) { - intrin_wmma_f16_16x16x16_f16_w64::Run(a, b, reg_c); + intrin_wmma_f16_16x16x16_f16_w64::Run(a, b, reg_c); } } }; - template struct wmma_type::Run(a, b, reg_c); + intrin_wmma_bf16_16x16x16_bf16_w32::Run(a, b, reg_c); } else if constexpr(wave_size == 64) { - intrin_wmma_bf16_16x16x16_bf16_w64::Run(a, b, reg_c); + intrin_wmma_bf16_16x16x16_bf16_w64::Run(a, b, reg_c); } } }; -#endif - template struct wmma_type +struct wmma_type> +{ + // Absolute fixing property + // * Data Pixel + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + // static constexpr index_t acc_data_size = 4; + // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t 
wave_size = Number{}; + // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; + // * num_acc_vgprs_per_wave alone M direction + // * num_subgroups alone M direction + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_f16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_bf16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_i32_16x16x16_iu8_w32_gfx12::Run( + a, b, reg_c); + } + } +}; + template - static constexpr auto GetWmma() + constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; +#else return WmmaInstr::wmma_f32_16x16x16_f16; +#endif } template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; +#else return WmmaInstr::wmma_f32_16x16x16_bf16; +#endif } template 
<> - static constexpr auto GetWmma() + constexpr auto GetWmma() { return WmmaInstr::wmma_f16_16x16x16_f16; } template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { return WmmaInstr::wmma_bf16_16x16x16_bf16; } template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; +#else return WmmaInstr::wmma_i32_16x16x16_iu8; +#endif } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> - static constexpr auto GetWmma() + constexpr auto GetWmma() { return WmmaInstr::wmma_i32_16x16x16_iu4; } @@ -346,7 +476,7 @@ struct WmmaSelector static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16"); static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * - selected_wmma.acc_data_size == + selected_wmma.acc_data_size * selected_wmma.acc_pack_number == selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4, "WRONG! Invalid Number of Accumulator Register"); } @@ -358,7 +488,8 @@ template + bool TransposeC = false, + bool AssemblyBackend = false> struct WmmaGemm { static constexpr auto I0 = Number<0>{}; @@ -369,14 +500,14 @@ struct WmmaGemm static constexpr auto I5 = Number<5>{}; using CIndex = MultiIndex<2>; - using CIndex4D = MultiIndex<4>; + using CIndex3D = MultiIndex<3>; __host__ __device__ constexpr WmmaGemm() { static_assert(NPerWmma == 16 && MPerWmma == 16, "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma"); - static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma"); + static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma"); } // WMMA output supporting C = A * B @@ -421,9 +552,49 @@ struct WmmaGemm Sequence<5>{})); } + // Transposed WMMA Output C' = B' * A' + template + __host__ __device__ static constexpr auto + MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) + { + const auto MBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, + make_tuple( + make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), + make_pass_through_transform(Number{}), + make_pass_through_transform(NBlockxRepeat), + make_pass_through_transform(NWave), + make_unmerge_transform(make_tuple(Number{}, + Number{}))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + } + __device__ static constexpr index_t GetRegSizePerWmma() { - return wmma_instr.num_acc_vgprs_per_wave; + return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number; } __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; } @@ -449,20 +620,25 @@ struct WmmaGemm , "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " "(int8, int32) or 
(int4, int32)!"); - if constexpr(!TransposeC) - { - wmma_instr.template run(p_a_wave, p_b_wave, p_c_thread); - } - else - { - wmma_instr.template run(p_b_wave, p_a_wave, p_c_thread); - } + static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + wmma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + } + else + { + wmma_instr.template run(p_b_wave[k], p_a_wave[k], p_c_thread); + } + }); } __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } __device__ static auto GetSubGroupId() { + static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups == + wmma_instr.wave_size, + ""); return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; } @@ -477,12 +653,20 @@ struct WmmaGemm __host__ __device__ static auto CalculateAThreadOriginDataIndex() { - return GetSwizzledLaneIdLow(); +#ifdef __gfx12__ + return GetLaneIdUnderSubGroup(); +#else + return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); +#endif } __host__ __device__ static auto CalculateBThreadOriginDataIndex() { +#ifdef __gfx12__ return GetLaneIdUnderSubGroup(); +#else + return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); +#endif } __device__ static CIndex GetBeginOfThreadBlk() @@ -493,6 +677,14 @@ struct WmmaGemm return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; } + __device__ static CIndex3D GetBeginOfThreadBlk3D() + { + index_t n_offset = GetLaneIdUnderSubGroup(); + index_t m_offset = GetSubGroupId(); + + return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0}; + } + static constexpr auto wmma = WmmaSelector{}; static constexpr auto wmma_instr = wmma.selected_wmma; @@ -500,7 +692,10 @@ struct WmmaGemm __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() { - return make_tuple(I1, I1, Number{}); + return make_tuple(I1, + I1, + Number{}, + Number{}); } }; diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 814969ef42baa2cc8e1f4188f61b3d50d74ea265..24fac91e22a01c7fbf7744ab48e7256d7e2ef900 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -31,7 +31,13 @@ enum struct MfmaInstr mfma_i32_16x16x32i8, mfma_f64_16x16x4f64, mfma_f32_32x32x16f8f8, - mfma_f32_16x16x32f8f8 + mfma_f32_16x16x32f8f8, + mfma_f32_32x32x16bf8bf8, + mfma_f32_16x16x32bf8bf8, + mfma_f32_32x32x16f8bf8, + mfma_f32_16x16x32f8bf8, + mfma_f32_32x32x16bf8f8, + mfma_f32_16x16x32bf8f8 }; template @@ -500,104 +506,242 @@ struct mfma_type } }; -template +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 
1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8f8::Run(a, b, reg_c); + } +}; + +template struct MfmaSelector { - template + template static constexpr auto GetMfma(); template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return 
MfmaInstr::mfma_f64_16x16x4f64; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_4x4x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_4x4x1xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x2xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x4xf32; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x8f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x16f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_4x4x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_4x4x4f16; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_32x32x8bf16_1k; @@ -607,7 +751,7 @@ struct MfmaSelector } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_16x16x16bf16_1k; @@ -618,41 +762,78 @@ struct MfmaSelector #if defined(CK_USE_AMD_MFMA_GFX940) template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_i32_32x32x16i8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_i32_16x16x32i8; } #else template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_i32_32x32x8i8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_i32_16x16x16i8; } #endif template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16f8f8; } template <> - static constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8f8; } - static constexpr auto selected_mfma = mfma_type()>{}; + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8bf8; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8bf8; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8bf8; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32f8bf8; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8f8; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8f8; + } + + static constexpr auto selected_mfma = + mfma_type()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -699,7 
+880,8 @@ template + typename additional_type = base_type, + bool TransposeC = false> struct XdlopsGemm { static constexpr auto I0 = Number<0>{}; @@ -850,10 +1032,14 @@ struct XdlopsGemm template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "base base_type must be double, float, half, bfloat16, and int8_t!"); + static_assert( + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || + (is_same::value && is_same::value) || + (is_same::value && is_same::value), + "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { if constexpr(!TransposeC) @@ -949,7 +1135,7 @@ struct XdlopsGemm return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; } - static constexpr auto mfma = MfmaSelector{}; + static constexpr auto mfma = MfmaSelector{}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp new file mode 100644 index 0000000000000000000000000000000000000000..56181d38c873833c10819449bd42f0a60f4ab3a4 --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp @@ -0,0 +1,391 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] +template +__host__ __device__ static auto +MakeGridDescriptorPair(const std::array& gs_ms_ns_lengths_vec, + const std::array& gs_ms_ns_strides_vec) +{ + // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + // gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN)) + // { + // throw std::runtime_error("wrong! dimension must match input lengths"); + // } + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto gs_ms_ns_lengths = + to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto gs_ms_ns_strides = + to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds); + + // lengths for N0, N1, ... 
+ const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds); + + if constexpr(TensorSpec == device::TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } + else + { + // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides); + + // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + // Note: This does not require padding as it only provides G offset calculation. Technically + // descriptor for only G is needed. Here we opt for backward compatibility purpose to return + // G_M_N + const auto grid_desc_g_mraw_nraw = + transform_tensor_descriptor(grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto c_ms_ns_lengths = to_tuple( + gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto c_ms_ns_strides = to_tuple( + gs_ms_ns_strides_vec, Number{}, Number{}); + + // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] 
+ const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); + + const auto grid_desc_mraw_nraw = transform_tensor_descriptor( + grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds - Number{}, nDimIds - Number{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } +} + +template + typename PerBlock_M_N_K_O, // Sequence<> + device::GemmSpecialization GemmSpec, + device::TensorSpecialization ASpec, + device::TensorSpecialization B0Spec, + device::TensorSpecialization B1Spec, + device::TensorSpecialization CSpec> +struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0); + static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1); + static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2); + static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3); + static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4); + + static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0); + static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1); + static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2); + static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3); + + static constexpr auto matrix_padder = + device::GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, OPerBlock}; + + // + // A + // + __host__ __device__ static auto MakeAGridDescriptorPair( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return MakeGridDescriptorPair(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec); + } + + // TODO: rename to G_MRaw_KRaw + __host__ __device__ static auto MakeAGridDescriptor_G_M_K( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first; + } + __host__ __device__ static auto MakeAGridDescriptor_M_K( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return matrix_padder.PadADescriptor_M_K( + MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + const AGridDesc_M_K& a_grid_desc_m_k, + const WmmaK&, + const MRepeat&, + const MWaves&, + const MPerWmma&, + const AK1&) + { + const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock; + const auto K = a_grid_desc_m_k.GetLength(I1); + const auto AKWmma = K / WmmaK{}; + constexpr auto AKRow = 2; + constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{}; + + return transform_tensor_descriptor( + a_grid_desc_m_k, 
+ make_tuple(make_unmerge_transform( + make_tuple(AKWmma, Number{}, Number{}, AK1{})), + make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // B (alias of B0) + // + __host__ __device__ static auto MakeB0GridDescriptorPair( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + return MakeGridDescriptorPair(b0_gs_ns_ks_lengths_vec, + b0_gs_ns_ks_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + __host__ __device__ static auto MakeB0GridDescriptor_G_N_K( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first; + } + __host__ __device__ static auto MakeB0GridDescriptor_N_K( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + // alias of matrix_padder.PadB0Descriptor_N_K + return matrix_padder.PadBDescriptor_N_K( + MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + const BGridDesc_L_K& b_grid_desc_l_k, + const WmmaK&, + const LRepeat&, + const LWaves&, + const LPerWmma&, + const BK1&) + { + const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock; + const auto K = b_grid_desc_l_k.GetLength(I1); + const auto BKWmma = K / WmmaK{}; + constexpr auto BKRow = 2; + constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{}; + + return transform_tensor_descriptor( + b_grid_desc_l_k, + make_tuple(make_unmerge_transform( + make_tuple(BKWmma, Number{}, Number{}, BK1{})), + make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // B1 + // + __host__ __device__ static auto MakeB1GridDescriptorPair( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + return MakeGridDescriptorPair(b1_gs_os_ns_lengths_vec, + b1_gs_os_ns_strides_vec); + } + + // TODO: rename to G_NRaw_KRaw + __host__ __device__ static auto MakeB1GridDescriptor_G_N_K( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first; + } + __host__ __device__ static auto MakeB1GridDescriptor_N_K( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + // alias of matrix_padder.PadB1Descriptor_O_N + return matrix_padder.PadB1Descriptor_N_K( + MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1) + { 
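        // Annotation (not part of the original patch): splits the K axis of the
        // [N, K] descriptor into (B1K0 = K / B1K1, B1K1) and reorders the result to
        // [B1K0, N, B1K1], i.e. the BK0/N/BK1 layout named by this helper and by the
        // other *_BK0_N_BK1 builders in this file.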
+ const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + const BGridDesc_N_L& b_grid_desc_n_l, + const WmmaL&, + const NRepeat&, + const NWaves&, + const NPerWmma&, + const BL1&) + { + const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock; + const auto L = b_grid_desc_n_l.GetLength(I1); + const auto BLWmma = L / WmmaL{}; + constexpr auto BLRow = 2; + constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{}; + + return transform_tensor_descriptor( + b_grid_desc_n_l, + make_tuple(make_unmerge_transform( + make_tuple(BLWmma, Number{}, Number{}, BL1{})), + make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // C + // + __host__ __device__ static auto MakeCGridDescriptorPair( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return MakeGridDescriptorPair(c_gs_ms_os_lengths_vec, + c_gs_ms_os_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + __host__ __device__ static auto MakeCGridDescriptor_G_M_N( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first; + } + __host__ __device__ static auto MakeCGridDescriptor_M_N( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return matrix_padder.PadCDescriptor_M_N( + MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second); + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index b1b203b43f381ef5779ac173c609ee893a70966b..2be0b66812434eb5127f5a97534ca4dc36b68273 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -164,6 +164,7 @@ template < index_t BK1, index_t GemmMPerBlock, index_t GemmNPerBlock, + index_t GemmKPerBlock, bool DoPadGemmM, bool DoPadGemmN> struct TransformConvBwdDataToGemm_v1 @@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - const index_t AK0 = - math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1); - if constexpr(NDimSpatial == 2) { // A: output tensor @@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1 const auto out_gemmk_gemmm_padded_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( out_gemmk_gemmmraw_grid_desc, - make_tuple(AK1, GemmMPerBlock), + make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); + const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( out_gemmk_gemmm_padded_grid_desc, 
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), @@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1 const auto out_gemmk_gemmm_padded_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( out_gemmk_gemmmraw_grid_desc, - make_tuple(AK1, GemmMPerBlock), + make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); + const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( out_gemmk_gemmm_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), @@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - const index_t BK0 = - math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1); - // B weight tensor if constexpr(NDimSpatial == 2) { @@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1 const auto wei_gemmk_gemmn_padded_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( wei_gemmk_gemmnraw_grid_desc, - make_tuple(BK1, GemmNPerBlock), + make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); + const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( wei_gemmk_gemmn_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), @@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto wei_gemmk_gemm_padded_grid_desc = + const auto wei_gemmk_gemmn_padded_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( wei_gemmk_gemmnraw_grid_desc, - make_tuple(BK1, GemmNPerBlock), + make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); + const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( - wei_gemmk_gemm_padded_grid_desc, - make_tuple( - make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))), + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c11bf845d05aba8d3cc843a354c2af6b7f2b1bd0 --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -0,0 +1,714 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
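A worked example (invented numbers) of the change made in the transform_conv_bwd_data_to_gemm_v1.hpp hunks above: instead of ceil-dividing the raw GEMM-K extent by AK1 (or BK1), the descriptor's K length is first padded up to a multiple of the new GemmKPerBlock parameter, and AK0/BK0 are then derived from that padded length.

    #include <cassert>
    #include <cstdio>

    // Illustrative arithmetic only; k_raw stands in for ZDotSlice*YDotSlice*XDotSlice*K.
    // GemmKPerBlock is chosen here as a multiple of AK1 so the padded length splits evenly.
    int main()
    {
        const int k_raw         = 150; // made-up raw K extent
        const int GemmKPerBlock = 32;
        const int AK1           = 8;

        // PadTensorDescriptor rounds the K length up to a multiple of GemmKPerBlock.
        const int k_padded = ((k_raw + GemmKPerBlock - 1) / GemmKPerBlock) * GemmKPerBlock;
        const int AK0      = k_padded / AK1; // what GetLength(I0) / AK1 yields in the hunk above

        assert(k_padded % AK1 == 0);
        std::printf("k_raw=%d k_padded=%d AK0=%d AK1=%d\n", k_raw, k_padded, AK0, AK1);
        return 0;
    }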
+ +#pragma once + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +template +struct TransformConvBwdWeightToGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(NStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride)); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return 
make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& /* input_strides */, + const std::array& /* weights_strides */, + const std::array& /* output_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + 
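A worked example (made-up sizes) of the split-K padding computed a few lines above in MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N: GemmKTotal (N * Wo in the 1D case shown here) is rounded up so that it factors exactly into GemmKBatch * GemmK0 * GemmK1, with GemmK0 itself a multiple of K0PerBlock.

    #include <cassert>
    #include <cstdio>

    // Illustrative only; mirrors the GemmK0 / GemmKPad formulas in the transform above.
    int main()
    {
        const int GemmKTotal = 1000; // e.g. N * Wo, invented
        const int GemmK1     = 8;    // GemmK1Number
        const int K0PerBlock = 4;
        const int GemmKBatch = 2;    // batch_k

        const int group  = GemmK1 * K0PerBlock * GemmKBatch;
        const int GemmK0 = ((GemmKTotal + group - 1) / group) * K0PerBlock; // integer_divide_ceil * K0PerBlock
        const int GemmKPad = GemmKBatch * GemmK0 * GemmK1;

        assert(GemmKPad >= GemmKTotal);   // the right-pad amount is GemmKPad - GemmKTotal
        assert(GemmK0 % K0PerBlock == 0); // each K-batch slice tiles evenly into K0 blocks
        std::printf("GemmK0=%d GemmKPad=%d pad=%d\n", GemmK0, GemmKPad, GemmKPad - GemmKTotal);
        return 0;
    }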
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = 
transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const 
auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } // function end +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bc290d56413ffb5c7214511c95ec97b6a8b0c617 --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -0,0 +1,640 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +/** + * @brief Transform conv bwd weight to gemm v2 + * + * This version does following things: + * 1. Merge KBatch with K0 to align descriptor with universal gemm + * 2. Merge Batch with M and N dimension. It allows to increase compute in + * case of small M and N. It also allows to vector load and store in case of + * K = 1, C = 1 and NHWGC layout. 
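The two points in the comment above can be made concrete with a small, self-contained sketch (all sizes invented; no composable_kernel API is called): the earlier transform exposes A as (GemmKBatch, GemmK0, GemmM, GemmK1), while this version exposes the merged (GemmKBatch * GemmK0, GemmM, GemmK1) shape and folds NumGroupsToMerge into GemmM and GemmN.

    #include <cstdio>

    // Dimension bookkeeping only, with invented sizes, to illustrate the doc comment above.
    int main()
    {
        const int K = 16, C = 8, Y = 3, X = 3;
        const int GemmKBatch = 2, GemmK0 = 8, GemmK1 = 4;
        const int NumGroupsToMerge = 4;

        // v1: A is (GemmKBatch, GemmK0, GemmM, GemmK1), with GemmM = K and GemmN = C*Y*X
        std::printf("v1 A: (%d, %d, %d, %d)\n", GemmKBatch, GemmK0, K, GemmK1);

        // v2: KBatch and K0 are merged, and NumGroupsToMerge scales both GEMM extents
        const int GemmM = K * NumGroupsToMerge;
        const int GemmN = C * Y * X * NumGroupsToMerge;
        std::printf("v2 A: (%d, %d, %d), GemmM=%d GemmN=%d\n",
                    GemmKBatch * GemmK0, GemmM, GemmK1, GemmM, GemmN);
        return 0;
    }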
+ */ +template +struct TransformConvBwdWeightToGemmV2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumGroupsToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumGroupsToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[4]; + const auto BatchStride = weights_strides[0]; + // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placehorder + // for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K, Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. 
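A minimal standalone illustration of the xor trick described in the comment above (the loop and names are invented; this is not CK code): the "N-side" group dimension is a length-1 slot padded out to NumGroupsToMerge, and indexing it with g_m ^ g_n sends every off-diagonal (g_m, g_n) pair into the padding while the diagonal hits the real weight data.

    #include <cassert>
    #include <cstdio>

    // Prints the slot table; slot 0 is the real element, any slot > 0 is padding.
    int main()
    {
        constexpr int NumGroupsToMerge = 4; // power of two, as the static_assert below requires

        for(int g_m = 0; g_m < NumGroupsToMerge; ++g_m)
        {
            for(int g_n = 0; g_n < NumGroupsToMerge; ++g_n)
            {
                const int slot = g_m ^ g_n;
                assert((slot == 0) == (g_m == g_n)); // xor is zero exactly on the diagonal
                std::printf("%d ", slot);
            }
            std::printf("\n");
        }
        return 0;
    }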
+ static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Y * X, NumGroupsToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumGroupsToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[5]; + const auto BatchStride = weights_strides[0]; + // Add NumGroupsToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K, Z * Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. 
Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Z * Y * X, NumGroupsToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * X * Y * NumGroupsToMerge; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / 
NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, NumGroupsToMerge, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, 
+ make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + 
const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * Z * X * Y * NumGroupsToMerge; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + 
Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumGroupsToMerge, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } // function end +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index cee3d2825b13facc912cdc8de21774ce793e9a2c..b91b12ad52380a82ac6213cea27016700aca1461 100644 --- 
a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,346 +14,465 @@ namespace ck { namespace tensor_operation { -template +template struct TransformConvFwdToGemm { + private: static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - - template , - bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } - const index_t Wi = a_g_n_c_wis_lengths[3]; + return acc; + } - const index_t Wo = c_g_n_k_wos_lengths[3]; + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); - const index_t ConvStrideW = conv_filter_strides[0]; + const IndexType N = a_g_n_c_wis_lengths[I1]; - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + if(element_space_size > TwoGB) { - const index_t NWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(NWo, C)); - - return in_gemmm_gemmk_desc; + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Not possible to support even after split N. + // Too large tensor. 
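A standalone sketch (invented tensor sizes) of the split-N policy implemented by GetSplitedNSize above: when the largest of the input/output tensors would exceed 2 GB of addressable space, N is split by its smallest divisor that brings each slice back under the limit, falling back to one image per slice if no such divisor exists.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Illustrative only; mirrors the divisor search shown in the transform above.
    int main()
    {
        const int64_t TwoGB              = int64_t{1} << 31;
        const int64_t element_space_size = int64_t{5} << 30; // ~5 GB, invented
        const int64_t N                  = 12;

        int64_t split_n = N;
        if(element_space_size > TwoGB)
        {
            const int64_t divisor = (element_space_size + TwoGB - 1) / TwoGB; // 3 here
            if(divisor <= N)
            {
                split_n = 1; // fallback: one convolution N per slice if no divisor is found
                for(int64_t d = divisor; d * d <= N; ++d)
                {
                    if(N % d == 0)
                    {
                        split_n = N / d; // smallest divisor of N that is >= `divisor`
                        break;
                    }
                }
            }
            // else: even one image per slice would not fit; keep N unchanged
        }
        std::printf("N=%lld -> split N=%lld\n", (long long)N, (long long)split_n);
        assert(split_n == 4); // 12 split by 3: each slice holds 4 images, each under 2 GB
        return 0;
    }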
+ return N; + } } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) + else { - const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + // Split N is not needed. + return N; + } + } - const auto in_n_wo_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + public: + __host__ __device__ constexpr TransformConvFwdToGemm() {} + + template + __host__ __device__ + TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base) + : N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_fwd_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_fwd_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_fwd_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_fwd_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_fwd_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_fwd_to_gemm_base.X_)}, + K_{static_cast(transform_conv_fwd_to_gemm_base.K_)}, + C_{static_cast(transform_conv_fwd_to_gemm_base.C_)}, + DiStride_{static_cast(transform_conv_fwd_to_gemm_base.DiStride_)}, + HiStride_{static_cast(transform_conv_fwd_to_gemm_base.HiStride_)}, + WiStride_{static_cast(transform_conv_fwd_to_gemm_base.WiStride_)}, + DoStride_{static_cast(transform_conv_fwd_to_gemm_base.DoStride_)}, + HoStride_{static_cast(transform_conv_fwd_to_gemm_base.HoStride_)}, + WoStride_{static_cast(transform_conv_fwd_to_gemm_base.WoStride_)}, + XStride_{static_cast(transform_conv_fwd_to_gemm_base.XStride_)}, + CStrideTensorA_{static_cast(transform_conv_fwd_to_gemm_base.CStrideTensorA_)}, + CStrideTensorB_{static_cast(transform_conv_fwd_to_gemm_base.CStrideTensorB_)}, + KStrideTensorB_{static_cast(transform_conv_fwd_to_gemm_base.KStrideTensorB_)}, + KStrideTensorC_{static_cast(transform_conv_fwd_to_gemm_base.KStrideTensorC_)}, + NStrideTensorA_{static_cast(transform_conv_fwd_to_gemm_base.NStrideTensorA_)}, + NStrideTensorC_{static_cast(transform_conv_fwd_to_gemm_base.NStrideTensorC_)}, + GStrideTensorA_{static_cast(transform_conv_fwd_to_gemm_base.GStrideTensorA_)}, + GStrideTensorB_{static_cast(transform_conv_fwd_to_gemm_base.GStrideTensorB_)}, + GStrideTensorC_{static_cast(transform_conv_fwd_to_gemm_base.GStrideTensorC_)}, + ConvStrideD_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadH_)}, + 
InRightPadW_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadW_)}, + ZYX_{static_cast(transform_conv_fwd_to_gemm_base.ZYX_)} + { + } - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + template ::type = false> + __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : Di_{I1}, + Hi_{I1}, + Wi_{a_g_n_c_wis_lengths[I3]}, + Do_{I1}, + Ho_{I1}, + Wo_{c_g_n_k_wos_lengths[I3]}, + Z_{I1}, + Y_{I1}, + X_{b_g_k_c_xs_lengths[I3]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + DiStride_{I1}, + HiStride_{I1}, + WiStride_{a_g_n_c_wis_strides[I3]}, + DoStride_{I1}, + HoStride_{I1}, + WoStride_{c_g_n_k_wos_strides[I3]}, + XStride_{b_g_k_c_xs_strides[I3]}, + CStrideTensorA_{a_g_n_c_wis_strides[I2]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + KStrideTensorC_{c_g_n_k_wos_strides[I2]}, + NStrideTensorA_{a_g_n_c_wis_strides[I1]}, + NStrideTensorC_{c_g_n_k_wos_strides[I1]}, + GStrideTensorA_{a_g_n_c_wis_strides[I0]}, + GStrideTensorB_{b_g_k_c_xs_strides[I0]}, + GStrideTensorC_{c_g_n_k_wos_strides[I0]}, + ConvStrideD_{I1}, + ConvStrideH_{I1}, + ConvStrideW_{conv_filter_strides[I0]}, + ConvDilationD_{I1}, + ConvDilationH_{I1}, + ConvDilationW_{conv_filter_dilations[I0]}, + InLeftPadD_{I0}, + InLeftPadH_{I0}, + InLeftPadW_{input_left_pads[I0]}, + InRightPadD_{I0}, + InRightPadH_{I0}, + InRightPadW_{input_right_pads[I0]}, + ZYX_{X_} + { + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); - return in_gemmm_gemmk_desc; + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); } else { - const index_t X = b_g_k_c_xs_lengths[3]; - const index_t ConvDilationW = conv_filter_dilations[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - - const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); - - const auto in_n_wip_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_desc = transform_tensor_descriptor( - in_n_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), - 
make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + N_ = c_g_n_k_wos_lengths[I1]; } } - template , - bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + template ::type = false> + __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : Di_{I1}, + Hi_{a_g_n_c_wis_lengths[I3]}, + Wi_{a_g_n_c_wis_lengths[I4]}, + Do_{I1}, + Ho_{c_g_n_k_wos_lengths[I3]}, + Wo_{c_g_n_k_wos_lengths[I4]}, + Z_{I1}, + Y_{b_g_k_c_xs_lengths[I3]}, + X_{b_g_k_c_xs_lengths[I4]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + DiStride_{I1}, + HiStride_{a_g_n_c_wis_strides[I3]}, + WiStride_{a_g_n_c_wis_strides[I4]}, + DoStride_{I1}, + HoStride_{c_g_n_k_wos_strides[I3]}, + WoStride_{c_g_n_k_wos_strides[I4]}, + XStride_{b_g_k_c_xs_strides[I4]}, + CStrideTensorA_{a_g_n_c_wis_strides[I2]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + KStrideTensorC_{c_g_n_k_wos_strides[I2]}, + NStrideTensorA_{a_g_n_c_wis_strides[I1]}, + NStrideTensorC_{c_g_n_k_wos_strides[I1]}, + GStrideTensorA_{a_g_n_c_wis_strides[I0]}, + GStrideTensorB_{b_g_k_c_xs_strides[I0]}, + GStrideTensorC_{c_g_n_k_wos_strides[I0]}, + ConvStrideD_{I1}, + ConvStrideH_{conv_filter_strides[I0]}, + ConvStrideW_{conv_filter_strides[I1]}, + ConvDilationD_{I1}, + ConvDilationH_{conv_filter_dilations[I0]}, + ConvDilationW_{conv_filter_dilations[I1]}, + InLeftPadD_{I0}, + InLeftPadH_{input_left_pads[I0]}, + InLeftPadW_{input_left_pads[I1]}, + InRightPadD_{I0}, + InRightPadH_{input_right_pads[I0]}, + InRightPadW_{input_right_pads[I1]}, + ZYX_{Y_ * X_} { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Hi = a_g_n_c_wis_lengths[3]; - const index_t Wi = a_g_n_c_wis_lengths[4]; - - const index_t Ho = c_g_n_k_wos_lengths[3]; - const index_t Wo = c_g_n_k_wos_lengths[4]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + if constexpr(SplitN) { - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C)); - - return in_gemmm_gemmk_desc; + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) + else { - const auto 
in_n_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + N_ = c_g_n_k_wos_lengths[I1]; + } + } - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + template ::type = false> + __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : Di_{a_g_n_c_wis_lengths[I3]}, + Hi_{a_g_n_c_wis_lengths[I4]}, + Wi_{a_g_n_c_wis_lengths[I5]}, + Do_{c_g_n_k_wos_lengths[I3]}, + Ho_{c_g_n_k_wos_lengths[I4]}, + Wo_{c_g_n_k_wos_lengths[I5]}, + Z_{b_g_k_c_xs_lengths[I3]}, + Y_{b_g_k_c_xs_lengths[I4]}, + X_{b_g_k_c_xs_lengths[I5]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + DiStride_{a_g_n_c_wis_strides[I3]}, + HiStride_{a_g_n_c_wis_strides[I4]}, + WiStride_{a_g_n_c_wis_strides[I5]}, + DoStride_{c_g_n_k_wos_strides[I3]}, + HoStride_{c_g_n_k_wos_strides[I4]}, + WoStride_{c_g_n_k_wos_strides[I5]}, + XStride_{b_g_k_c_xs_strides[I5]}, + CStrideTensorA_{a_g_n_c_wis_strides[I2]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + KStrideTensorC_{c_g_n_k_wos_strides[I2]}, + NStrideTensorA_{a_g_n_c_wis_strides[I1]}, + NStrideTensorC_{c_g_n_k_wos_strides[I1]}, + GStrideTensorA_{a_g_n_c_wis_strides[I0]}, + GStrideTensorB_{b_g_k_c_xs_strides[I0]}, + GStrideTensorC_{c_g_n_k_wos_strides[I0]}, + ConvStrideD_{conv_filter_strides[I0]}, + ConvStrideH_{conv_filter_strides[I1]}, + ConvStrideW_{conv_filter_strides[I2]}, + ConvDilationD_{conv_filter_dilations[I0]}, + ConvDilationH_{conv_filter_dilations[I1]}, + ConvDilationW_{conv_filter_dilations[I2]}, + InLeftPadD_{input_left_pads[I0]}, + InLeftPadH_{input_left_pads[I1]}, + InLeftPadW_{input_left_pads[I2]}, + InRightPadD_{input_right_pads[I0]}, + InRightPadH_{input_right_pads[I1]}, + InRightPadW_{input_right_pads[I2]}, + ZYX_{Z_ * Y_ * X_} + { + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); - return in_gemmm_gemmk_desc; + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); } else { - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const 
auto in_n_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + N_ = c_g_n_k_wos_lengths[I1]; } } - template , - bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + __host__ bool AreDescriptorsSmallerThan2GB() const { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; + constexpr long_index_t TwoGB = (long_index_t{1} << 31); - const index_t Di = a_g_n_c_wis_lengths[3]; - const index_t Hi = a_g_n_c_wis_lengths[4]; - const index_t Wi = a_g_n_c_wis_lengths[5]; + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_; - const index_t Do = c_g_n_k_wos_lengths[3]; - const index_t Ho = c_g_n_k_wos_lengths[4]; - const index_t Wo = c_g_n_k_wos_lengths[5]; + bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB; - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate real filter 
size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow the split only if the whole left padding falls in the left tensor and the whole + // right padding falls in the right tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); + const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) { - const index_t NDoHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C)); - - return in_gemmm_gemmk_desc; + // Apply new sizes + // Split the output in half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + // Calculate offsets + a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + c_right_offset = (Do_ / 2) * DoStride_; } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) + else if(is_possible_to_split_h) { - const auto in_n_di_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; - const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; - const auto 
in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_do_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; - return in_gemmm_gemmk_desc; + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + c_right_offset = (Ho_ / 2) * HoStride_; } - else + else if(is_possible_to_split_w) { - const index_t Z = b_g_k_c_xs_lengths[3]; - const index_t Y = b_g_k_c_xs_lengths[4]; - const index_t X = b_g_k_c_xs_lengths[5]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; - const auto in_n_di_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); - const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), - make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), - make_tuple(Sequence<0>{}, 
Sequence<1>{})); - - return in_gemmm_gemmk_desc; + a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + c_right_offset = (Wo_ / 2) * WoStride_; } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); } // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as @@ -361,391 +480,764 @@ struct TransformConvFwdToGemm template || - is_same_v), + is_same_v || + is_same_v), bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + __host__ __device__ auto MakeADescriptor_M_K() const { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Wi = a_g_n_c_wis_lengths[3]; - - const index_t Wo = c_g_n_k_wos_lengths[3]; - - const index_t ConvStrideW = conv_filter_strides[0]; - if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, 
Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0) { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; - - const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); - - const auto in_n_wo_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, 
WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else { - const index_t X = b_g_k_c_xs_lengths[3]; - const index_t ConvDilationW = conv_filter_dilations[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; - - const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); - - const auto in_n_wip_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_desc = transform_tensor_descriptor( - in_n_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_merge_transform(make_tuple(X_, C_))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, 
GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(X_, C_))), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } } template || - is_same_v), + is_same_v || + is_same_v), bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Hi = a_g_n_c_wis_lengths[3]; - const index_t Wi = a_g_n_c_wis_lengths[4]; - - const index_t Ho = c_g_n_k_wos_lengths[3]; - const index_t Wo = c_g_n_k_wos_lengths[4]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; + __host__ __device__ auto MakeADescriptor_M_K() const + { if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + 
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0) { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - - const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( - 
in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else { - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - 
make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_merge_transform(make_tuple(Y_, X_, C_))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, 
Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Y_, X_, C_))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } } template || - is_same_v), + is_same_v || + is_same_v), bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Di = a_g_n_c_wis_lengths[3]; - const index_t Hi = a_g_n_c_wis_lengths[4]; - const index_t Wi = a_g_n_c_wis_lengths[5]; - - const index_t Do = c_g_n_k_wos_lengths[3]; - const index_t Ho = c_g_n_k_wos_lengths[4]; - const index_t Wo = c_g_n_k_wos_lengths[5]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; + __host__ __device__ auto MakeADescriptor_M_K() const + { if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { - const index_t NDoHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride)); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_), + 
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0) { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto 
CStride = I1; - - const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - - const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_do_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else { - const index_t Z = b_g_k_c_xs_lengths[3]; - const index_t Y = b_g_k_c_xs_lengths[4]; - const index_t X = b_g_k_c_xs_lengths[5]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const 
index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto CStride = I1; - - const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), - make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + 
make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } } @@ -754,20 +1246,53 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto - MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */) + __host__ __device__ auto MakeBDescriptor_N_K() const { - const index_t K = b_g_k_c_xs_lengths[1]; - const index_t C = b_g_k_c_xs_lengths[2]; - - const index_t YX = ck::accumulate_n( - b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto wei_gemmn_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); - - return wei_gemmn_gemmk_desc; + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + using FilterSizeNumType = + std::conditional_t, + std::conditional_t, Number<27>>>; + + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{})); + } + else + { + + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + 
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(FilterSizeNumType{})), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_)); + } + else + { + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, ZYX_ * C_), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(ZYX_ * C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } } template < @@ -779,25 +1304,14 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + __host__ __device__ auto MakeBDescriptor_N_K() const { - const index_t K = b_g_k_c_xs_lengths[1]; - const index_t C = b_g_k_c_xs_lengths[2]; - - const index_t YX = ck::accumulate_n( - b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const index_t KStride = b_g_k_c_xs_strides[1]; - const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; - const auto CStride = I1; - const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( - make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); + make_tuple(K_, ZYX_, C_), make_tuple(KStrideTensorB_, XStride_, CStrideTensorB_)); const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( wei_k_yx_c_desc, - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), + make_tuple(make_pass_through_transform(K_), make_merge_transform(make_tuple(ZYX_, C_))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -805,74 +1319,255 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v, + index_t NDimSp = NDimSpatial, + + typename std::enable_if), bool>::type = false> - static auto - MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */) + __host__ __device__ auto MakeCDescriptor_M_N() const { - const index_t N = c_g_n_k_wos_lengths[1]; - const index_t K = c_g_n_k_wos_lengths[2]; - - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } - const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); + template ), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); } - template < - typename CLayout, - typename std::enable_if || - is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v, - bool>::type = false> - static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) - { - const index_t N = c_g_n_k_wos_lengths[1]; - const index_t K = c_g_n_k_wos_lengths[2]; + template ), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } - 
const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + const IndexType NDoHoWo = N_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple( + NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } - const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + const IndexType NDoHoWo = N_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. 
+ // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } - // for output bias template || - is_same_v, - bool>::type = false> - static auto - MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */) + index_t NDimSp = NDimSpatial, + typename std::enable_if< + NDimSp == 3 && (is_same_v || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const { - const index_t N = c_g_n_k_wos_lengths[1]; - const index_t K = c_g_n_k_wos_lengths[2]; - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + DoStride_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. 
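+ // For example, with NumGroupsToMerge = 4 the xor of the two group indices (one taken
+ // from the merged M dimension, one from the merged N dimension) picks the position along
+ // the dummy padded dimension: it is 0, i.e. real memory, only for diagonal blocks, and
+ // lands in the padding for every off-diagonal block. Requiring a power-of-two
+ // NumGroupsToMerge keeps that xor result inside [0, NumGroupsToMerge), which is why no
+ // modulo is needed.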
+ static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + IndexType N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType DiStride_, HiStride_, WiStride_; + IndexType DoStride_, HoStride_, WoStride_; + IndexType XStride_; + IndexType CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_; + IndexType NStrideTensorA_, NStrideTensorC_; + IndexType GStrideTensorA_, GStrideTensorB_, GStrideTensorC_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType ZYX_; +}; - const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1)); +// wrapper class to call member functions on TransformConvToGemm struct at runtime +// TODO: figure out aq way to properly pass in layout as an argument +struct TransformConv +{ + TransformConv() {} - return out_gemmm_gemmn_desc; + template + auto + transform_func(TransformConvFwdToGemm conv_fwd_to_gemm) + { + if(NDimSpatial == 2) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(); + } + else if(NDimSpatial == 3) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(); + } + else if(NDimSpatial == 1) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(); + } } }; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8100b0bdbd0a24d25d7fcb829b62a1c355cc0b9a --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
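+//
+// Helpers for grouped convolutions whose tensors use NGCHW-style layouts (batch, group, channel,
+// spatial). The descriptor makers below view the NGCHW tensor and its NHWGC-transposed
+// counterpart as padded 2D matrices (merged [N, G, C] rows by merged spatial columns), which is
+// intended to let an elementwise transpose kernel convert the data so the existing NHWGC
+// convolution kernels can be reused. TransposeStrides returns the strides the convolution should
+// use once that transpose has been applied.
+//
+// Illustrative call pattern (template arguments elided; the arrays follow the [G, N, C, spatial...]
+// ordering used throughout this file):
+//
+//   using Transform   = ck::tensor_operation::TransformConvNGCHWToNHWGC</* layouts, NDimSpatial,
+//                                                                          MPerThread, NPerThread */>;
+//   const auto src_2d = Transform::MakeNGCHWTransposeDesc(g_n_c_wis_lengths, g_n_c_wis_strides);
+//   const auto dst_2d = Transform::MakeNHWGCTransposeDesc(g_n_c_wis_lengths, g_n_c_wis_strides);
+//   const auto conv_strides = Transform::TransposeStrides(g_n_c_wis_lengths, g_n_c_wis_strides);
+//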
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" + +namespace ck { +namespace tensor_operation { + +template +struct TransformConvNGCHWToNHWGC +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + template ::type = false> + static auto MakeNGCHWTransposeDesc(std::array g_n_c_wis_lengths, + std::array g_n_c_wis_strides) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t& N = g_n_c_wis_lengths[I1]; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Wi = g_n_c_wis_lengths[I3]; + + const index_t& GStride = g_n_c_wis_strides[I0]; + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t& CStride = g_n_c_wis_strides[I2]; + const index_t& WiStride = g_n_c_wis_strides[I3]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto MakeNHWGCTransposeDesc(std::array g_n_c_wis_lengths, + std::array g_n_c_wis_strides) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t& N = g_n_c_wis_lengths[I1]; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Wi = g_n_c_wis_lengths[I3]; + + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t WiStride = G * C; + const index_t GStride = C; + const index_t CStride = 1; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto MakeNGCHWTransposeDesc(std::array g_n_c_wis_lengths, + std::array g_n_c_wis_strides) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t& N = g_n_c_wis_lengths[I1]; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Hi = g_n_c_wis_lengths[I3]; + const index_t& Wi = g_n_c_wis_lengths[I4]; + + const index_t& GStride = g_n_c_wis_strides[I0]; + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t& CStride = g_n_c_wis_strides[I2]; + const index_t& HiStride = g_n_c_wis_strides[I3]; + const index_t& WiStride = g_n_c_wis_strides[I4]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Hi, Wi))), + make_tuple(Sequence<0, 1, 
2>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto MakeNHWGCTransposeDesc(std::array g_n_c_wis_lengths, + std::array g_n_c_wis_strides) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t& N = g_n_c_wis_lengths[I1]; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Hi = g_n_c_wis_lengths[I3]; + const index_t& Wi = g_n_c_wis_lengths[I4]; + + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t HiStride = Wi * G * C; + const index_t WiStride = G * C; + const index_t GStride = C; + const index_t CStride = 1; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Hi, Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto MakeNGCHWTransposeDesc(std::array g_n_c_wis_lengths, + std::array g_n_c_wis_strides) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t& N = g_n_c_wis_lengths[I1]; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Di = g_n_c_wis_lengths[I3]; + const index_t& Hi = g_n_c_wis_lengths[I4]; + const index_t& Wi = g_n_c_wis_lengths[I5]; + + const index_t& GStride = g_n_c_wis_strides[I0]; + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t& CStride = g_n_c_wis_strides[I2]; + const index_t& DiStride = g_n_c_wis_strides[I3]; + const index_t& HiStride = g_n_c_wis_strides[I4]; + const index_t& WiStride = g_n_c_wis_strides[I5]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Di, Hi, Wi), + make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Di, Hi, Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto MakeNHWGCTransposeDesc(std::array g_n_c_wis_lengths, + std::array g_n_c_wis_strides) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t& N = g_n_c_wis_lengths[I1]; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Di = g_n_c_wis_lengths[I3]; + const index_t& Hi = g_n_c_wis_lengths[I4]; + const index_t& Wi = g_n_c_wis_lengths[I5]; + + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t DiStride = Hi * Wi * G * C; + const index_t HiStride = Wi * G * C; + const index_t WiStride = G * C; + const index_t GStride = C; + const index_t CStride = 1; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Di, Hi, Wi), + make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Di, Hi, Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( 
+ merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + static auto TransposeStrides(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides) + { + if constexpr(device::is_NGCHW_GKYXC_NGKHW() || + device::is_NGCDHW_GKZYXC_NGKDHW()) + { + std::array g_n_c_wis_strides_transposed; + const auto G = g_n_c_wis_lengths[I0]; + const auto C = g_n_c_wis_lengths[I2]; + + g_n_c_wis_strides_transposed[I0] = C; + g_n_c_wis_strides_transposed[I1] = g_n_c_wis_strides[I1]; + g_n_c_wis_strides_transposed[I2] = I1; + if constexpr(NDimSpatial == 2) + { + g_n_c_wis_strides_transposed[I3] = g_n_c_wis_lengths[I4] * G * C; + g_n_c_wis_strides_transposed[I4] = G * C; + } + else if constexpr(NDimSpatial == 3) + { + g_n_c_wis_strides_transposed[I3] = + g_n_c_wis_lengths[I4] * g_n_c_wis_lengths[I5] * G * C; + g_n_c_wis_strides_transposed[I4] = g_n_c_wis_lengths[I5] * G * C; + g_n_c_wis_strides_transposed[I5] = G * C; + } + return g_n_c_wis_strides_transposed; + } + else + { + // transpose not needed + return g_n_c_wis_strides; + } + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 897cb4f249f415a1ca3a6660063ebc946641bd47..d4ee5c886cd416e02de46a83213fc0bdc4e621ad 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" @@ -297,586 +297,297 @@ enum struct AmdBufferCoherenceEnum GLC = 1, SLC = 2, GLC_SLC = 3, + // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 + // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system + // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse + WAVE_NT0 = 0, + WAVE_NT1 = 2, + GROUP_NT0 = 1, + GROUP_NT1 = 3, + DEVICE_NT0 = 8, + DEVICE_NT1 = 10, + SYSTEM_NT0 = 9, + SYSTEM_NT1 = 11, }; -template -__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) +template +__device__ typename vector_type::type +amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { - static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), - "wrong! not implemented"); + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! 
not implemented"); - if constexpr(is_same::value) + if constexpr(N == 1) { - // use fp32 load to mimic fp64 load - if constexpr(N == 1) - { - const float2_t tmp = - llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 2) - { - const float4_t tmp = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 4) - { - const float4_t f32_0 = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - const float4_t f32_1 = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - vector_type tmp; - - tmp.AsType()(Number<0>{}) = bit_cast(f32_0); - tmp.AsType()(Number<1>{}) = bit_cast(f32_1); - - return tmp.AsType()(Number<0>{}); - } + return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 2) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - } - else if constexpr(N == 8) - { - vector_type tmp; - tmp.AsType()(Number<0>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - - return tmp.AsType()(Number<0>{}); - } + return bit_cast(tmp); } - else if constexpr(is_same::value) + else if constexpr(N == 4) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - // use fp32 load to mimic fp16 load - float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - 
else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - return bit_cast(tmp); - } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - vector_type tmp; - - tmp.AsType()(Number<0>{}) = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int32_t), - static_cast(coherence)); - return tmp.AsType()(Number<0>{}); - } + return bit_cast(tmp); } - else if constexpr(is_same::value) + else if constexpr(N == 8) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); -#else - int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - return bit_cast(tmp); -#endif - } - else if constexpr(N == 4) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); -#else - int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + return bit_cast(tmp); + } + else if constexpr(N == 16) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); + return bit_cast(tmp); + } + else if constexpr(N == 32) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + vector_type tmp; - return bit_cast(tmp); -#endif - } - else if constexpr(N == 8) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - vector_type tmp; - - tmp.AsType()(Number<0>{}) = - 
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int8_t), - static_cast(coherence)); - - return tmp.AsType()(Number<0>{}); -#else - int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; - return bit_cast(tmp); -#endif - } - else if constexpr(N == 16) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - vector_type tmp; - - tmp.AsType()(Number<0>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int8_t), - static_cast(coherence)); - - tmp.AsType()(Number<2>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 8 * sizeof(int8_t), - static_cast(coherence)); - - tmp.AsType()(Number<3>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 12 * sizeof(int8_t), - static_cast(coherence)); - - return tmp.AsType()(Number<0>{}); -#else - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast(tmp); + } + else if constexpr(N == 64) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp2 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp3 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int32_t), + static_cast(coherence)); - return bit_cast(tmp); -#endif - } + vector_type tmp; + + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; + tmp.AsType()(Number<2>{}) = tmp2; + tmp.AsType()(Number<3>{}) = tmp3; + + return bit_cast(tmp); } } template -__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) +__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { static_assert( - (is_same::value && (N == 1 || N == 2)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N 
== 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) + using r_t = typename vector_type::type; + auto raw_data = amd_buffer_load_impl_raw( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); + return bit_cast(raw_data); +} + +template +__device__ void +amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); + + if constexpr(N == 1) { - // use fp32 store to mimic fp64 store - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } + llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 2) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - vector_type tmp{src_thread_data}; - llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - } + + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 4) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, - 
dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { -#if 0 - vector_type tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(half_t), - static_cast(coherence)); -#else - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#endif - } + llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 8) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i16(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - vector_type tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(bhalf_t), - static_cast(coherence)); - } + llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 16) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i32(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } + llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 32) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i8(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#else - llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), - 
dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#endif - } - else if constexpr(N == 4) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#else - llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#endif - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 16) - { - llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + } + else if constexpr(N == 64) + { + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 8, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 12, + static_cast(coherence)); + } +} + +template +__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + using r_t = typename vector_type::type; + + amd_buffer_store_impl_raw(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); +} + +template +__device__ void amd_global_atomic_add_impl(const typename vector_type::type src_thread_data, + T* addr) +{ + static_assert((is_same::value && (N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 2 || N == 4 || N == 8)), + "wrong! 
not implemented"); + + if constexpr(is_same::value) + { + vector_type tmp{src_thread_data}; + static_for<0, N / 2, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast(addr) + i, + tmp.template AsType()[i]); + }); } +#if defined(__gfx942__) + else if constexpr(is_same::value) + { + vector_type tmp{src_thread_data}; + static_for<0, N / 2, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast(addr) + i, + tmp.template AsType()[i]); + }); + } +#endif } template @@ -1127,31 +838,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); - if constexpr(is_same::value) - { - auto tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); - return bit_cast(tmp); - } - else - { - return amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); - } #else - if constexpr(is_same::value) - { - auto tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - return src_thread_element_valid ? bit_cast(tmp) : vector_t(0); - } - else - { - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - return src_thread_element_valid ? tmp : vector_t(0); - } + + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + return src_thread_element_valid ? tmp : vector_t(0); #endif } @@ -1209,34 +903,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - - if constexpr(is_same::value) - { - auto tmp = - bit_cast::type::type>(src_thread_data); - amd_buffer_store_impl( - tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); - } - else - { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); - } + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else if(dst_thread_element_valid) { - if constexpr(is_same::value) - { - auto tmp = bit_cast::type::type>( - src_thread_data); - amd_buffer_store_impl( - tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } - else - { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif } @@ -1262,18 +935,29 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; + if constexpr(is_same::value) + { + if(dst_thread_element_valid) + { + amd_global_atomic_add_impl( + src_thread_data, p_dst_wave + dst_thread_element_offset); + } + } + else + { #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x80000000; - amd_buffer_atomic_add_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); -#else - if(dst_thread_element_valid) - { amd_buffer_atomic_add_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } #endif + } } // buffer_atomic_max requires: @@ -1311,4 +995,52 @@ amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thr #endif } +// Direct loads from global to LDS. +__device__ void +llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, + __attribute__((address_space(3))) uint32_t* lds_ptr, + index_t size, + index_t voffset, + index_t soffset, + index_t offset, + index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); + +template +__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, + const index_t global_offset, + T* lds_base_ptr, + const index_t lds_offset, + const bool is_valid, + const index_t src_element_space_size) +{ + // Direct loads require that each thread reads and writes exactly a single DWORD. + constexpr auto dword_bytes = 4; + constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; + static_assert(bytes_per_thread == dword_bytes); + + const uint32_t* global_ptr = + reinterpret_cast(reinterpret_cast(global_base_ptr)); + const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); + const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; + +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + T* lds_ptr = lds_base_ptr + lds_offset; + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), + "s"(src_resource) + : "memory"); +#else + // LDS pointer must be attributed with the LDS address space. + __attribute__((address_space(3))) uint32_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); + + llvm_amdgcn_raw_buffer_load_lds( + src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); +#endif +} + } // namespace ck diff --git a/include/ck/utility/amd_gemm_dpp.hpp b/include/ck/utility/amd_gemm_dpp.hpp index 8d6c7eede9d9fe38efd40ab12d6c401c999b6d5b..a28292dade3dcad8ddcd5bd3c0cd119aeb4b4253 100644 --- a/include/ck/utility/amd_gemm_dpp.hpp +++ b/include/ck/utility/amd_gemm_dpp.hpp @@ -5,17 +5,63 @@ #include "ck/utility/common_header.hpp" #include "ck/utility/math.hpp" -#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/inner_product_dpp8.hpp" namespace ck { namespace dpp8 { -/// Number of lanes that can share data using DPP8 modifiers. -constexpr index_t lane_group_size = 8; +template +struct dpp_datatypes; -__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; } -__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; } +template <> +struct dpp_datatypes +{ + // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a + // single instruction. 
+ using a_dtype = half_t; + using b_dtype = half_t; + using c_dtype = float; + static constexpr index_t k_per_instr = 2; +}; + +template +struct DppLanegroupGemm +{ + using datatypes_conf = dpp_datatypes; + using ADataType = typename datatypes_conf::a_dtype; + using BDataType = typename datatypes_conf::b_dtype; + using CDataType = typename datatypes_conf::c_dtype; + + __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec) + { + constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread; + + const vector_type a_vector{a_vec}; + const vector_type b_vector{b_vec}; + + static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) { + float c = c_vec.template AsType()(c_idx); + // Next `c_idx` implies that we need to pull data from the next lane. + constexpr index_t source_lane = c_idx; + static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) { + const auto a_k_vec = a_vector.template AsType()[k_chunk]; + const auto b_k_vec = b_vector.template AsType()[k_chunk]; + ck::dpp8:: + inner_product_dpp( + a_k_vec, b_k_vec, c); + }); + c_vec.template AsType()(c_idx) = c; + }); + } +}; } // namespace dpp8 diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 43baa817d3672c0b2eed6059a88351aeb19a5209..5dc67a5aded4af289d0240394f720af62e699eb4 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0 "0"(c0), "1"(c1)); #else - c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); #endif } @@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, "2"(c2), "3"(c3)); #else - c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); - c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); - c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); #endif } @@ -355,17 +355,5 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, c3); } -// Ranged input operand -__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c) -{ -#if defined(__gfx11__) - asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c)); -#else - ignore = a; - ignore = b; - ignore = c; -#endif -} - } // namespace ck #endif diff --git a/include/ck/utility/amd_lds.hpp b/include/ck/utility/amd_lds.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c218fded96788214688ed28fd22993dfa2f5dc66 --- /dev/null +++ b/include/ck/utility/amd_lds.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
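The new amd_lds.hpp header starting here adds lds_utils::AllocateLdsBuffers, which carves a single LDS pointer into NumBuffers aligned sub-buffers (see the documentation comment further down in this file). Below is a minimal host-side sketch of the offset arithmetic it performs; the buffer count, size, and alignment are made-up illustration values, and integer_least_multiple is a stand-in for ck::math::integer_least_multiple, not the library code itself.

#include <cstdio>

// Round x up to the next multiple of align (stand-in for ck::math::integer_least_multiple).
constexpr int integer_least_multiple(int x, int align) { return ((x + align - 1) / align) * align; }

int main()
{
    constexpr int num_buffers          = 3;   // hypothetical multi-stage pipeline
    constexpr int num_elems_per_buffer = 100; // elements of the LDS data type
    constexpr int lds_alignment        = 64;  // alignment given in elements
    constexpr int start_offset_elems   = 0;

    const int single_buffer_offset = integer_least_multiple(num_elems_per_buffer, lds_alignment);
    for(int i = 0; i < num_buffers; ++i)
    {
        // Buffers land at element offsets 0, 128, 256 from the start of LDS.
        std::printf("buffer %d -> element offset %d\n",
                    i,
                    start_offset_elems + i * single_buffer_offset);
    }
    return 0;
}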
+ +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/dynamic_buffer.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +namespace lds_utils { + +/** \brief Allocate a given number of buffers in LDS and return them as a tuple. + * + * \tparam DataType Data type of elements to be stored in LDS. + * \tparam NumBuffers Number of buffers to be allocated. + * \param lds_ptr Address of the beginning of LDS space. + * \param num_elems_per_buffer Number of elements to allocate per single buffer. + * \param start_offset_elems Number of elements to move from the start of LDS for the allocation of + * the first buffer. \param lds_alignment Alignment of every buffer allocation given as a number of + * elements. \return Tuple of dynamic buffers representing memory allocated in LDS. + */ +template +__device__ static auto AllocateLdsBuffers(void* lds_ptr, + int32_t num_elems_per_buffer, + int32_t start_offset_elems, + int32_t lds_alignment) +{ + const DataType* lds_start = static_cast(lds_ptr) + start_offset_elems; + const int32_t single_buffer_offset = + math::integer_least_multiple(num_elems_per_buffer, lds_alignment); + return generate_tuple( + [&](auto i) { + const int32_t local_offset = i * single_buffer_offset; + return make_dynamic_buffer(lds_start + local_offset, + num_elems_per_buffer); + }, + Number{}); +} + +} // namespace lds_utils +} // namespace ck diff --git a/include/ck/utility/amd_smfmac.hpp b/include/ck/utility/amd_smfmac.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b6b094ff23aded21c1a9b48f6b40192f72f9e1d --- /dev/null +++ b/include/ck/utility/amd_smfmac.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#pragma once + +namespace ck { + +template +struct intrin_smfmac_f32_16x16x32f16; + +// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse +// indices from reg_idx +template <> +struct intrin_smfmac_f32_16x16x32f16<16, 16> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, abid); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +template +struct intrin_smfmac_f32_16x16x32bf16; + +template <> +struct intrin_smfmac_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, abid); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +template +struct intrin_smfmac_f32_32x32x16f16; + +template <> +struct intrin_smfmac_f32_32x32x16f16<32, 32> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, abid); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +template +struct 
intrin_smfmac_f32_32x32x16bf16; + +template <> +struct intrin_smfmac_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, abid); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +} // namespace ck diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 59c38bf3afd9759fcb626c58723057ede9e90bb4..1647e9ae2d4f3ababe258360d47f9c7d0ebbb06e 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -94,14 +94,36 @@ using get_carrier_t = typename get_carrier::type; } // namespace detail +__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + __device__ inline int32_t amd_wave_read_first_lane(int32_t value) { return __builtin_amdgcn_readfirstlane(value); } -template ::value && - ck::is_trivially_copyable::value>> +__device__ inline int64_t amd_wave_read_first_lane(int64_t value) +{ + constexpr unsigned object_size = sizeof(int64_t); + constexpr unsigned second_part_offset = object_size / 2; + auto* const from_obj = reinterpret_cast(&value); + alignas(int64_t) std::byte to_obj[object_size]; + + using Sgpr = uint32_t; + + *reinterpret_cast(to_obj) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj)); + *reinterpret_cast(to_obj + second_part_offset) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj + second_part_offset)); + + return *reinterpret_cast(to_obj); +} + +template < + typename Object, + typename = std::enable_if_t && std::is_trivially_copyable_v>> __device__ auto amd_wave_read_first_lane(const Object& obj) { using Size = unsigned; diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index dd7f0b770a11470037b1d33164d830cb9819e8d1..aa519fb2be7af0130d9379f43a7ab258c8cad5e2 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -9,6 +9,15 @@ // TODO: Add arch limitation namespace ck { +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ + defined(__gfx1103__) || defined(__gfx11_generic__) +#define __gfx11__ +#endif + +#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) +#define __gfx12__ +#endif + /********************************WAVE32 MODE***********************************************/ // src: fp16, dst: fp32 @@ -25,7 +34,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> // delete them. 
// amd_assembly_wmma_f32_16x16x16_f16_w32( // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else @@ -46,7 +55,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); @@ -71,7 +80,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else @@ -95,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) || defined(__gfx12__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); @@ -117,7 +126,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( neg_a, @@ -145,7 +154,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else @@ -166,7 +175,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); @@ -191,7 +200,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else @@ -215,7 +224,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); @@ -237,7 +246,7 @@ struct 
intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( neg_a, @@ -254,5 +263,83 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> } }; +// gfx12 +/********************************WAVE32 MODE***********************************************/ + +// src: fp16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { + // * Inline assembly need to elimate the duplicated data load, compiler won't help you + // delete them. + // amd_assembly_wmma_f32_16x16x16_f16_w32( + // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12; + +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck #endif diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index ea7755036fc64b0af6aa340e062e1277c485e785..a955279bc849c7958c61b789c58d2445c5fd8894 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,12 +1,13 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
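The WMMA changes above and the XDLops changes below replace long per-chip guards (defined(__gfx1100__) || defined(__gfx1101__) || ...) with umbrella macros such as __gfx11__, __gfx12__, and __gfx94__. A hedged sketch of that pattern follows, using hypothetical macro names so it compiles on any host compiler; it is not the library's actual target detection.

#include <cstdio>

#define __MY_GFX1101__ 1 // pretend exactly one gfx11-class chip macro is predefined

// Collapse the per-chip checks into a single family macro, once.
#if defined(__MY_GFX1100__) || defined(__MY_GFX1101__) || defined(__MY_GFX1102__) || \
    defined(__MY_GFX1103__)
#define __MY_GFX11_FAMILY__
#endif

int main()
{
#if defined(__MY_GFX11_FAMILY__)
    std::printf("gfx11-family code path\n"); // later guards test one macro instead of four
#else
    std::printf("fallback code path\n");
#endif
    return 0;
}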
-#ifndef CK_AMD_XDLOPS_HPP -#define CK_AMD_XDLOPS_HPP - -#include "data_type.hpp" +#pragma once namespace ck { +// Define the common macro for gfx94x models +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define __gfx94__ +#endif // fp32 template @@ -326,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16> __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) { reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast(reg_a), - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}], - 0, - 0, - 0); + __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); } }; @@ -344,7 +345,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> template __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -364,7 +365,7 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32> template __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(reg_a), @@ -396,7 +397,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> template __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bit_cast(reg_a), bit_cast(reg_b), @@ -417,5 +418,194 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> #endif } }; -} // namespace ck + +template +struct intrin_mfma_f32_32x32x16bf8bf8; + +template <> +struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8bf8; + +template <> +struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template 
+struct intrin_mfma_f32_32x32x16f8bf8; + +template <> +struct intrin_mfma_f32_32x32x16f8bf8<32, 32> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32f8bf8; + +template <> +struct intrin_mfma_f32_16x16x32f8bf8<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); #endif + } +}; + +template +struct intrin_mfma_f32_32x32x16bf8f8; + +template <> +struct intrin_mfma_f32_32x32x16bf8f8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8f8; + +template <> +struct intrin_mfma_f32_16x16x32bf8f8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +} // namespace ck diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 2004afef237600dc940cbaaa9dc122d1b539e05b..4551dfdce277b95225074a7d97262cf9f3f5d586 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -36,6 +36,8 @@ struct Array return *this; } + __host__ __device__ constexpr const TData* begin() const { return &mData[0]; } + __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; } }; // empty Array diff --git 
a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..902195e2f89395ce17b50be4cb3a5f85700299e0 --- /dev/null +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +enum struct BlockGemmPipelineScheduler +{ + Intrawave, + Interwave, +}; + +enum struct TailNumber +{ + // Single / Double buffer pipeline + Odd, + Even, + + // Long prefetch pipeline, up to 8 + One, + Two, + Three, + Four, + Five, + Six, + Seven, + + // Unroll stages > Prefetch stages, number of loop is multiple of unroll stages + Empty, + // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add + // prefetchstages + Full, +}; +template +struct BlockwiseGemmXdlops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 64; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); + + static constexpr index_t A_LDS_Read_Width = ALDSReadWidth; + static constexpr index_t B_LDS_Read_Width = BLDSReadWidth; + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t C_MFMA_Inst_Num = + MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + KPerXDL); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C MFMA inst: %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_MFMA_Inst_Num); + } +}; + +} // namespace ck diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 8f896a4fca7075c4cce6f04e4685bcbe4305de58..7e24d601df2c6be00c2a5053c72ffd0d2416154e 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -24,13 +24,28 @@ using std::byte; using bhalf_t = ushort; using half_t = _Float16; -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -using int4_t = _BitInt(4); -#endif -using f8_t = uint8_t; +using int4_t = _BitInt(4); +using f8_t = _BitInt(8); +using bf8_t = unsigned _BitInt(8); + +inline constexpr auto next_pow2(uint32_t x) +{ + // Precondition: x > 1. 
+ return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; +} + +// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool +template +inline constexpr bool is_native_type() +{ + return is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value; +} // vector_type -template +template struct vector_type; // Caution: DO NOT REMOVE @@ -149,6 +164,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = uint8_t; + static constexpr index_t vector_size = 1; +}; + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> struct scalar_type @@ -165,9 +187,22 @@ struct scalar_type static constexpr index_t vector_size = 1; }; -// +template <> +struct scalar_type +{ + using type = bf8_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bool; + static constexpr index_t vector_size = 1; +}; + template -struct vector_type +struct vector_type()>> { using d1_t = T; using type = d1_t; @@ -185,7 +220,8 @@ struct vector_type template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value, "wrong!"); + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); return data_.d1x1_; } @@ -193,14 +229,16 @@ struct vector_type template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value, "wrong!"); + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); return data_.d1x1_; } }; +__device__ int static err = 0; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -221,7 +259,8 @@ struct vector_type template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value, "wrong!"); + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -231,12 +270,17 @@ struct vector_type { return data_.d2x1_; } + else + { + return err; + } } template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value, "wrong!"); + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -246,11 +290,15 @@ struct vector_type { return data_.d2x1_; } + else + { + return err; + } } }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -274,7 +322,7 @@ struct vector_type __host__ __device__ constexpr const auto& AsType() const { static_assert(is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -288,13 +336,17 @@ struct vector_type { return data_.d4x1_; } + else + { + return err; + } } template __host__ __device__ constexpr auto& AsType() { static_assert(is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -308,11 +360,15 @@ struct vector_type { return data_.d4x1_; } + else + { + return err; + } } }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -339,7 +395,7 @@ 
struct vector_type { static_assert(is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -357,6 +413,10 @@ struct vector_type { return data_.d8x1_; } + else + { + return err; + } } template @@ -364,7 +424,7 @@ struct vector_type { static_assert(is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -382,11 +442,15 @@ struct vector_type { return data_.d8x1_; } + else + { + return err; + } } }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -416,7 +480,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -438,6 +502,10 @@ struct vector_type { return data_.d16x1_; } + else + { + return err; + } } template @@ -446,7 +514,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -468,11 +536,15 @@ struct vector_type { return data_.d16x1_; } + else + { + return err; + } } }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -504,7 +576,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -530,6 +602,10 @@ struct vector_type { return data_.d32x1_; } + else + { + return err; + } } template @@ -538,7 +614,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -564,11 +640,15 @@ struct vector_type { return data_.d32x1_; } + else + { + return err; + } } }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -603,7 +683,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -633,6 +713,10 @@ struct vector_type { return data_.d64x1_; } + else + { + return err; + } } template @@ -642,7 +726,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -672,11 +756,15 @@ struct vector_type { return data_.d64x1_; } + else + { + return err; + } } }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -713,7 +801,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -747,6 +835,10 @@ struct vector_type { return data_.d128x1_; } + else + { + 
return err; + } } template @@ -756,7 +848,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -790,11 +882,15 @@ struct vector_type { return data_.d128x1_; } + else + { + return err; + } } }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -833,7 +929,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -871,6 +967,10 @@ struct vector_type { return data_.d256x1_; } + else + { + return err; + } } template @@ -880,7 +980,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -918,152 +1018,747 @@ struct vector_type { return data_.d256x1_; } + else + { + return err; + } } }; -using int64_t = long; - -// fp64 -using double2_t = typename vector_type::type; -using double4_t = typename vector_type::type; - -// fp32 -using float2_t = typename vector_type::type; -using float4_t = typename vector_type::type; -using float8_t = typename vector_type::type; -using float16_t = typename vector_type::type; -using float32_t = typename vector_type::type; -using float64_t = typename vector_type::type; - -// fp16 -using half2_t = typename vector_type::type; -using half4_t = typename vector_type::type; -using half8_t = typename vector_type::type; -using half16_t = typename vector_type::type; -using half32_t = typename vector_type::type; -using half64_t = typename vector_type::type; - -// bfp16 -using bhalf2_t = typename vector_type::type; -using bhalf4_t = typename vector_type::type; -using bhalf8_t = typename vector_type::type; -using bhalf16_t = typename vector_type::type; -using bhalf32_t = typename vector_type::type; -using bhalf64_t = typename vector_type::type; - -// i32 -using int32x2_t = typename vector_type::type; -using int32x4_t = typename vector_type::type; -using int32x8_t = typename vector_type::type; -using int32x16_t = typename vector_type::type; -using int32x32_t = typename vector_type::type; -using int32x64_t = typename vector_type::type; +template +struct non_native_vector_base +{ + using type = non_native_vector_base; -// i8 -using int8x2_t = typename vector_type::type; -using int8x4_t = typename vector_type::type; -using int8x8_t = typename vector_type::type; -using int8x16_t = typename vector_type::type; -using int8x32_t = typename vector_type::type; -using int8x64_t = typename vector_type::type; + __host__ __device__ non_native_vector_base() = default; + __host__ __device__ non_native_vector_base(const type&) = default; + __host__ __device__ non_native_vector_base(type&&) = default; + __host__ __device__ ~non_native_vector_base() = default; -// f8 -using f8x2_t = typename vector_type::type; -using f8x4_t = typename vector_type::type; -using f8x8_t = typename vector_type::type; -using f8x16_t = typename vector_type::type; -using f8x32_t = typename vector_type::type; -using f8x64_t = typename vector_type::type; + T d[N]; +}; +// non-native vector_type implementation 
template -struct NumericLimits; - -template <> -struct NumericLimits +struct vector_type()>> { - __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; } - - __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; } - - __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; } + using d1_t = T; + using type = d1_t; - __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; } + union alignas(next_pow2(1 * sizeof(T))) + { + d1_t d1_; + StaticallyIndexedArray d1x1_; + } data_; - __host__ __device__ static constexpr int32_t QuietNaN() { return 0; } -}; + __host__ __device__ constexpr vector_type() : data_{type{}} {} -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; } + __host__ __device__ constexpr vector_type(type v) : data_{v} {} - __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; } + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); - __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; } + return data_.d1x1_; + } - __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; } + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); - __host__ __device__ static constexpr int16_t QuietNaN() { return 0; } + return data_.d1x1_; + } }; -template <> -struct NumericLimits +template +struct vector_type()>> { - __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; } - - __host__ __device__ static constexpr int8_t Min() noexcept { return -128; } + using d1_t = T; + using d2_t = non_native_vector_base; - __host__ __device__ static constexpr int8_t Max() noexcept { return 127; } + using type = d2_t; - __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; } + union alignas(next_pow2(2 * sizeof(T))) + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; - __host__ __device__ static constexpr int8_t QuietNaN() { return 0; } -}; + __host__ __device__ constexpr vector_type() : data_{type{}} {} -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; } + __host__ __device__ constexpr vector_type(type v) : data_{v} {} - __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; } + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); - __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; } + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } - __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; } + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); - __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; } + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } }; -template <> -struct 
NumericLimits +template +struct vector_type()>> { - __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; } - - __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; } - - __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; } + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; - __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; } + using type = d4_t; - __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; } -}; + union alignas(next_pow2(4 * sizeof(T))) + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; -template <> -struct NumericLimits -{ - static constexpr unsigned int binary_min = 0x00800000; - static constexpr unsigned int binary_max = 0x7F7FFFFF; - static constexpr unsigned int binary_lowest = 0xFF7FFFFF; - static constexpr unsigned int binary_qnan = 0xFFC00001; - static constexpr unsigned int binary_inf = 0x7F8000000; + __host__ __device__ constexpr vector_type() : data_{type{}} {} - __host__ __device__ static constexpr float Min() { return bit_cast(binary_min); } + __host__ __device__ constexpr vector_type(type v) : data_{v} {} - __host__ __device__ static constexpr float Max() { return bit_cast(binary_max); } + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); - __host__ __device__ static constexpr float Lowest() { return bit_cast(binary_lowest); } + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } - __host__ __device__ static constexpr float QuietNaN() { return bit_cast(binary_qnan); } + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + + using type = d8_t; + + union alignas(next_pow2(8 * sizeof(T))) + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value 
|| + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + + using type = d16_t; + + union alignas(next_pow2(16 * sizeof(T))) + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + + using type = d32_t; + + union alignas(next_pow2(32 * sizeof(T))) + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + 
else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + using d64_t = non_native_vector_base; + + using type = d64_t; + + union alignas(next_pow2(64 * sizeof(T))) + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +using int64_t = long; + +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using 
half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; +using half64_t = typename vector_type::type; + +// bfp16 +using bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; +using bhalf64_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// f8 +using f8x2_t = typename vector_type::type; +using f8x4_t = typename vector_type::type; +using f8x8_t = typename vector_type::type; +using f8x16_t = typename vector_type::type; +using f8x32_t = typename vector_type::type; +using f8x64_t = typename vector_type::type; + +// bf8 +using bf8x2_t = typename vector_type::type; +using bf8x4_t = typename vector_type::type; +using bf8x8_t = typename vector_type::type; +using bf8x16_t = typename vector_type::type; +using bf8x32_t = typename vector_type::type; +using bf8x64_t = typename vector_type::type; + +// u8 +using uint8x2_t = typename vector_type::type; +using uint8x4_t = typename vector_type::type; +using uint8x8_t = typename vector_type::type; +using uint8x16_t = typename vector_type::type; +using uint8x32_t = typename vector_type::type; +using uint8x64_t = typename vector_type::type; + +template +struct NumericLimits; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; } + + __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int32_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; } + + __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int16_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; } + + __host__ __device__ static constexpr int8_t Min() noexcept { return -128; } + + __host__ __device__ static constexpr int8_t Max() noexcept { return 127; } + + __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int8_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Max() 
noexcept { return 4294967295U; }
+
+    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<uint16_t>
+{
+    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
+
+    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<float>
+{
+    static constexpr unsigned int binary_min    = 0x00800000;
+    static constexpr unsigned int binary_max    = 0x7F7FFFFF;
+    static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
+    static constexpr unsigned int binary_qnan   = 0xFFC00001;
+    static constexpr unsigned int binary_inf    = 0x7F800000;
+
+    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
+
+    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
+
+    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
+
+    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
 
     __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
 };
@@ -1100,18 +1795,112 @@ struct NumericLimits
 template <>
 struct NumericLimits<f8_t>
 {
+    // negative zero nan mode with exp bias = 8
     static constexpr uint8_t binary_min    = 0x08; // 0b00001000
-    static constexpr uint8_t binary_max    = 0x77; // 0b01110111
-    static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
+    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
+    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
     static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000
+    // ieee mode with exp bias = 7
+    // static constexpr uint8_t binary_min    = 0x08; // 0b00001000
+    // static constexpr uint8_t binary_max    = 0x77; // 0b01110111
+    // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
+    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0
 
-    __host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
+    __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
 
-    __host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
+    __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
 
-    __host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
+    __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
 
-    __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
+    __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
 };
 
+template <>
+struct NumericLimits<bf8_t>
+{
+    // negative zero nan mode with exp bias = 16
+    static constexpr uint8_t binary_min    = 0x04; // 0b00000100
+    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
+    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
+    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000
+    // ieee mode with exp bias = 15
+    // static constexpr uint8_t binary_min    = 0x04; // 0b00000100
+    // static constexpr uint8_t binary_max    = 0x7B; // 0b01111011
+    // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
+    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0
+
+    __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
+
+    __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
+
+    __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
+
+    __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
+};
+
+template <typename T>
+struct NumericUtils
+{
+};
+
+template <>
+struct NumericUtils<float>
+{
+    static constexpr int exp            = 8;
+    static constexpr int mant           = 23;
+    static constexpr int bias           = 127;
+    static constexpr uint32_t nan_mask  = 0x7F800000;
+    static constexpr uint32_t head_mask = 0xFF800000;
+    static constexpr uint32_t mant_mask = 0x7FFFFF;
+    static constexpr uint32_t exp_mask  = 0xFF;
+    static constexpr uint32_t Inf       = 0x7F800000;
+    static constexpr uint32_t NegInf    = 0xFF800000;
+    static constexpr uint32_t NaN       = 0x7F800001;
+    static constexpr uint32_t Neg0      = 0x80000000;
+    using bitwise_type                  = uint32_t;
+};
+
+template <>
+struct NumericUtils<half_t>
+{
+    static constexpr int exp            = 5;
+    static constexpr int mant           = 10;
+    static constexpr int bias           = 15;
+    static constexpr uint16_t nan_mask  = 0x7C00;
+    static constexpr uint16_t head_mask = 0xFC00;
+    static constexpr uint16_t mant_mask = 0x3FF;
+    static constexpr uint16_t exp_mask  = 0x1F;
+    static constexpr uint32_t Inf       = 0x7C00;
+    static constexpr uint32_t NegInf    = 0xFC00;
+    static constexpr uint32_t NaN       = 0x7C01;
+    static constexpr uint32_t Neg0      = 0x8000;
+    using bitwise_type                  = uint16_t;
+};
+
+template <>
+struct NumericUtils<f8_t>
+{
+    static constexpr int exp  = 4;
+    static constexpr int mant = 3;
+    static constexpr int bias = 8; // negative zero nan mode
+    // static constexpr int bias = 7; // ieee mode
+};
+
+template <>
+struct NumericUtils<bf8_t>
+{
+    static constexpr int exp  = 5;
+    static constexpr int mant = 2;
+    static constexpr int bias = 16; // negative zero nan mode
+    // static constexpr int bias = 15; // ieee mode
+};
+
+template <>
+struct NumericUtils<bhalf_t>
+{
+    static constexpr int exp  = 8;
+    static constexpr int mant = 7;
+    static constexpr int bias = 128; // negative zero nan mode
+    // static constexpr int bias = 127; // ieee mode
+};
 
 } // namespace ck
diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp
index 802b145284a2f96c91c927b4c72b4d4f72b9166a..fc3c53c820b654b10b20b5ef818e49aef7b3a221 100644
--- a/include/ck/utility/debug.hpp
+++ b/include/ck/utility/debug.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
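The NANOO ("negative zero is NaN") limits above are easiest to check with a quick host-side decode. The sketch below is illustrative only: the helper decode_fp8_nanoo and the file name are made up for this note, not CK APIs. It assumes the e4m3/e5m2 layouts and the biases 8 and 16 listed in NumericUtils<f8_t> and NumericUtils<bf8_t>, and reproduces Max() = 240 for f8_t and Max() = 57344 for bf8_t, with 0x80 as the sole NaN encoding.

// decode_fp8_limits.cpp -- illustrative only, not part of this patch
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an 8-bit NANOO value: 0x80 is the only NaN, exponent == 0 encodes
// subnormals, everything else is (-1)^s * 2^(e - bias) * 1.m.
double decode_fp8_nanoo(uint8_t bits, int exp_bits, int mant_bits, int bias)
{
    if(bits == 0x80)
        return NAN;
    const int sign     = bits >> (exp_bits + mant_bits);
    const int exponent = (bits >> mant_bits) & ((1 << exp_bits) - 1);
    const int mantissa = bits & ((1 << mant_bits) - 1);
    const double frac  = static_cast<double>(mantissa) / (1 << mant_bits);
    const double mag   = (exponent == 0) ? std::ldexp(frac, 1 - bias)
                                         : std::ldexp(1.0 + frac, exponent - bias);
    return sign != 0 ? -mag : mag;
}

int main()
{
    // f8_t (e4m3, bias 8):  Max() = 0x7F -> 240,   Min() = 0x08 -> 2^-7
    std::printf("f8_t  max = %g  min = %g\n",
                decode_fp8_nanoo(0x7F, 4, 3, 8), decode_fp8_nanoo(0x08, 4, 3, 8));
    // bf8_t (e5m2, bias 16): Max() = 0x7F -> 57344, Min() = 0x04 -> 2^-15
    std::printf("bf8_t max = %g  min = %g\n",
                decode_fp8_nanoo(0x7F, 5, 2, 16), decode_fp8_nanoo(0x04, 5, 2, 16));
    return 0;
}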
#ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP @@ -79,6 +79,13 @@ __device__ void print_shared(T const* p_shared, index_t num_elements) __syncthreads(); } +template +__device__ static bool is_thread_local_1d_id_idx() +{ + const auto tid = get_thread_local_1d_id(); + return ((tid == Ids) || ...); +} + } // namespace debug } // namespace ck diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 02d61f34ed5192480d09f4432eec30e65e7264b9..0dcc514a2f6548d6ca4ac5f8d8c89ee09775131c 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -140,13 +140,59 @@ struct DynamicBuffer } else if constexpr(Op == InMemoryDataOperationEnum::Add) { - auto tmp = this->template Get(i, is_valid_element); - this->template Set(i, is_valid_element, x + tmp); - // tmp += x; - // this->template Set(i, is_valid_element, tmp); + auto tmp = this->template Get(i, is_valid_element); + using scalar_t = typename scalar_type>::type; + // handle bfloat addition + if constexpr(is_same_v) + { + if constexpr(is_scalar_type::value) + { + // Scalar type + auto result = + type_convert(type_convert(x) + type_convert(tmp)); + this->template Set(i, is_valid_element, result); + } + else + { + // Vector type + constexpr auto vector_size = scalar_type>::vector_size; + const vector_type a_vector{tmp}; + const vector_type b_vector{x}; + static_for<0, vector_size, 1>{}([&](auto idx) { + auto result = type_convert( + type_convert(a_vector.template AsType()[idx]) + + type_convert(b_vector.template AsType()[idx])); + this->template Set(i + idx, is_valid_element, result); + }); + } + } + else + { + this->template Set(i, is_valid_element, x + tmp); + } } } + template + __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf, + index_t src_offset, + index_t dst_offset, + bool is_valid_element) const + { + // Copy data from global to LDS memory using direct loads. + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, + "Source data must come from a global memory buffer."); + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "Destination data must be stored in an LDS memory buffer."); + + amd_direct_load_global_to_lds(p_data_, + src_offset, + dst_buf.p_data_, + dst_offset, + is_valid_element, + element_space_size_); + } + template >::type, typename scalar_type>::type>::value, @@ -312,13 +358,15 @@ struct DynamicBuffer bool constexpr use_amd_buffer_addressing = is_same_v, int32_t> || is_same_v, float> || - (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); #elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; #elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = is_same_v, float> || - (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); #else bool constexpr use_amd_buffer_addressing = false; #endif diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6455402dcb331d91240ebe09d4d553f4d355f96e --- /dev/null +++ b/include/ck/utility/env.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace internal { +template +struct ParseEnvVal +{ +}; + +template <> +struct ParseEnvVal +{ + static bool parse_env_var_value(const char* vp) + { + std::string value_env_str{vp}; + + for(auto& c : value_env_str) + { + if(std::isalpha(c) != 0) + { + c = std::tolower(static_cast(c)); + } + } + + if(value_env_str == "disable" || value_env_str == "disabled" || value_env_str == "0" || + value_env_str == "no" || value_env_str == "off" || value_env_str == "false") + { + return false; + } + else if(value_env_str == "enable" || value_env_str == "enabled" || value_env_str == "1" || + value_env_str == "yes" || value_env_str == "on" || value_env_str == "true") + { + return true; + } + else + { + throw std::runtime_error("Invalid value for env variable"); + } + + return false; // shouldn't reach here + } +}; + +// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default). +// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string). +template <> +struct ParseEnvVal +{ + static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); } +}; + +template <> +struct ParseEnvVal +{ + static std::string parse_env_var_value(const char* vp) { return std::string{vp}; } +}; + +template +struct EnvVar +{ + private: + T value{}; + bool is_unset = true; + + public: + const T& GetValue() const { return value; } + + bool IsUnset() const { return is_unset; } + + void Unset() { is_unset = true; } + + void UpdateValue(const T& val) + { + is_unset = false; + value = val; + } + + explicit EnvVar(const char* const name, const T& def_val) + { + // NOLINTNEXTLINE (concurrency-mt-unsafe) + const char* vp = std::getenv(name); + if(vp != nullptr) // a value was provided + { + is_unset = false; + value = ParseEnvVal::parse_env_var_value(vp); + } + else // no value provided, use default value + { + value = def_val; + } + } +}; +} // end namespace internal + +// static inside function hides the variable and provides +// thread-safety/locking +// Used in global namespace +#define CK_DECLARE_ENV_VAR(name, type, default_val) \ + namespace ck::env { \ + struct name \ + { \ + static_assert(std::is_same_v, \ + "CK_DECLARE_ENV* must be used in the global namespace"); \ + using value_type = type; \ + static ck::internal::EnvVar& Ref() \ + { \ + static ck::internal::EnvVar var{#name, default_val}; \ + return var; \ + } \ + }; \ + } + +#define CK_DECLARE_ENV_VAR_BOOL(name) CK_DECLARE_ENV_VAR(name, bool, false) + +#define CK_DECLARE_ENV_VAR_UINT64(name) CK_DECLARE_ENV_VAR(name, uint64_t, 0) + +#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "") + +#define CK_ENV(name) \ + ck::env::name {} + +template +inline const std::string& EnvGetString(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsEnabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsDisabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue(); +} + +template +inline uint64_t EnvValue(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsUnset(EnvVar) +{ + return EnvVar::Ref().IsUnset(); +} + +template +void EnvUnset(EnvVar) +{ + EnvVar::Ref().Unset(); +} + +/// updates the cached value of an 
environment variable +template +void UpdateEnvVar(EnvVar, const ValueType& val) +{ + static_assert(std::is_same_v); + EnvVar::Ref().UpdateValue(val); +} + +template +void UpdateEnvVar(EnvVar, const std::string_view& val) +{ + EnvVar::Ref().UpdateValue( + ck::internal::ParseEnvVal::parse_env_var_value(val.data())); +} + +} // namespace ck diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index ef4d16a8d09de7ee898ba6815efee98974c97a20..2533073225f0b2ffed36ec5c339f416bb3381cfa 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -16,59 +16,46 @@ enum class f8_rounding_mode stochastic }; +__host__ inline int clz(uint32_t x) { return __builtin_clz(x); } +__device__ inline int clz(uint32_t x) { return __clz(x); } + } // namespace ck namespace ck::utils { namespace { -template -__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) +template +__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) { - // check data type - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - - // fp8 exponent/mantissa layout - constexpr int f8_exp = 4; - constexpr int f8_mant = 3; + // fp8/bf8 exponent/mantissa layout + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; - // resulting type exponent/mantissa layout - constexpr int type_exp = is_half ? 5 : 8; - constexpr int type_mant = is_half ? 10 : 23; + // original type exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; - int exponent; + int exponent, bias; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr uint8_t nan_code = 0x80; - constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; + constexpr Y nan_code = 0x80; + constexpr uint32_t nan_mask = NumericUtils::nan_mask; // convert to bitwise - typedef typename ck::conditional::value, uint16_t, uint32_t>::type - T_bitwise; - T_bitwise x_bitwise = *(reinterpret_cast(&x)); + using T_bitwise = typename NumericUtils::bitwise_type; + T_bitwise x_bitwise = bit_cast(x); // unpack the input, depends on datatype - if constexpr(is_float) - { - head = x_bitwise & 0xFF800000; - mantissa = x_bitwise & 0x7FFFFF; - exponent = (head >> type_mant) & 0xFF; - sign = head >> (type_exp + type_mant); - } - else if constexpr(is_half) - { - head = x_bitwise & 0xFC00; - mantissa = x_bitwise & 0x3FF; - exponent = (head >> type_mant) & 0x1F; - sign = head >> (type_exp + type_mant); - } + head = x_bitwise & NumericUtils::head_mask; + mantissa = x_bitwise & NumericUtils::mant_mask; + exponent = (head >> in_mant) & NumericUtils::exp_mask; + sign = head >> (in_exp + in_mant); + bias = NumericUtils::bias; - uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant); - uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1; - constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2); - constexpr int exp_low_cutoff = - (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); + uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; + constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 
1 : 2); if constexpr(negative_zero_nan) { @@ -85,39 +72,103 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) if(x_bitwise == 0) return 0; - exponent -= exp_low_cutoff - 1; - if(exponent <= 0) - drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; - mantissa += 1 << type_mant; - // apply random number if needed - mantissa += (stoch ? rng : mantissa) & drop_mask; - if(mantissa >= (2 << type_mant)) - { - mantissa >>= 1; - exponent++; + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again3 + + // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits + const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // out_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, out_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. +In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = out_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= out_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = out_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << in_mant); // Add the implicit 1 into mantissa } - mantissa >>= (type_mant - f8_mant); - // check negative exponent - if(exponent <= 0) + bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == + (1 << (in_mant - out_mant + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. 
For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << in_mant); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + out_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + bool odd = + mantissa & + (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(out_exponent == 0) { - if(x_bitwise == 0) - return 0; - else + if((1 << in_mant) & mantissa) + { + out_exponent = 1; // denormal overflow to become normal, promote exponent + // No need to make 1 implicit now as it will be addressed later + } + } + else + { + if((1 << (in_mant + 1)) & mantissa) { - // subnormal range; represented by a subnormal float8 (exponent 0) - // and involves loss of accuracy - mantissa >>= 1 - exponent; - exponent = 0; + mantissa >>= 1; + out_exponent++; + // No need to make 1 implicit now as it will be addressed later } } - // above range: quantize to maximum possible float of the same sign - else if(exponent > max_exp) + + mantissa >>= (in_mant - out_mant); + + if(out_exponent > max_exp) { - if(clip) + if constexpr(clip) { - mantissa = (1 << f8_mant) - 1; - exponent = max_exp; + mantissa = (1 << out_mant) - 1; + out_exponent = max_exp; } else { @@ -126,125 +177,117 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) } // check if x is 0.0 or -0.0 - if(exponent == 0 && mantissa == 0) - return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); - mantissa &= (1 << f8_mant) - 1; - return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; + if(out_exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + mantissa &= (1 << out_mant) - 1; + return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; } -template -__host__ __device__ T run_cast_from_f8(f8_t x) +template +__host__ __device__ Y run_cast_from_f8(X x) { - // check data type - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - - // fp8 exponent/mantissa layout - constexpr int f8_exp = 4; - constexpr int f8_mant = 3; + // fp8/bf8 exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; // resulting type exponent/mantissa layout - constexpr int type_exp = is_half ? 5 : 8; - constexpr int type_mant = is_half ? 
10 : 23; + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; // prepare the codes - constexpr uint8_t nan_code = 0x80; - T fInf, fNegInf, fNaN, fNeg0; - if constexpr(is_half) - { - constexpr uint16_t ihInf = 0x7C00; - constexpr uint16_t ihNegInf = 0xFC00; - constexpr uint16_t ihNaN = 0x7C01; - constexpr uint16_t ihNeg0 = 0x8000; - fInf = *(reinterpret_cast(&ihInf)); - fNegInf = *(reinterpret_cast(&ihNegInf)); - fNaN = *(reinterpret_cast(&ihNaN)); - fNeg0 = *(reinterpret_cast(&ihNeg0)); - } - else if constexpr(is_float) - { - constexpr uint32_t ifInf = 0x7F800000; - constexpr uint32_t ifNegInf = 0xFF800000; - constexpr uint32_t ifNaN = 0x7F800001; - constexpr uint32_t ifNeg0 = 0x80000000; - fInf = *(reinterpret_cast(&ifInf)); - fNegInf = *(reinterpret_cast(&ifNegInf)); - fNaN = *(reinterpret_cast(&ifNaN)); - fNeg0 = *(reinterpret_cast(&ifNeg0)); - } + constexpr X nan_code = 0x80; + using T_bitwise = typename NumericUtils::bitwise_type; + + constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; + constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; + constexpr T_bitwise NaN_bitwise = NumericUtils::NaN; + constexpr T_bitwise Neg0_bitwise = NumericUtils::Neg0; + + constexpr Y Inf = bit_cast(Inf_bitwise); + constexpr Y NegInf = bit_cast(NegInf_bitwise); + constexpr Y NaN = bit_cast(NaN_bitwise); + constexpr Y Neg0 = bit_cast(Neg0_bitwise); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); // unpack the input - uint32_t sign = x >> (f8_exp + f8_mant); - uint32_t mantissa = x & ((1 << f8_mant) - 1); - int exponent = (x & 0x7F) >> f8_mant; + uint32_t sign = x >> (in_exp + in_mant); + uint32_t mantissa = x & ((1 << in_mant) - 1); + int exponent = (x & 0x7F) >> in_mant; constexpr int exp_low_cutoff = - (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); - typename ck::conditional::value, uint16_t, uint32_t>::type retval; + (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + T_bitwise retval; if constexpr(negative_zero_nan) { if(x == nan_code) - return fNaN; + return NaN; } else { if(x == nan_code) - return fNeg0; - if(exponent == ((1 << f8_exp) - 1)) - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + return Neg0; + if(exponent == ((1 << in_exp) - 1)) + return (mantissa == 0) ? (sign ? 
NegInf : Inf) : NaN; + } + + if constexpr((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && + !negative_zero_nan) + { + retval = x; + retval <<= 8; + return bit_cast(retval); } // subnormal input if(exponent == 0) { // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); + int sh = 1 + clz(mantissa) - (32 - in_mant); mantissa <<= sh; - mantissa &= ((1 << f8_mant) - 1); exponent += 1 - sh; + mantissa &= ((1 << in_mant) - 1); } exponent += exp_low_cutoff - 1; - mantissa <<= type_mant - f8_mant; + mantissa <<= out_mant - in_mant; // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) if(exponent <= 0) { - mantissa |= 1 << type_mant; + mantissa |= 1 << out_mant; mantissa >>= 1 - exponent; exponent = 0; } - retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; - return *(reinterpret_cast(&retval)); + retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; + return bit_cast(retval); } } // namespace -template -__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) +template +__host__ __device__ Y cast_to_f8(X x, uint32_t rng) { - // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "Only half and float can be casted to f8."); + // check datatypes + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted."); - return run_cast_to_f8(x, rng); + return run_cast_to_f8(x, rng); } -template -__host__ __device__ T cast_from_f8(f8_t x) +template +__host__ __device__ Y cast_from_f8(X x) { // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; static_assert(is_half || is_float, "only half and float are supported."); - // check if x is 0.0 - if(x == 0) - return static_cast(0); - - return run_cast_from_f8(x); + return run_cast_from_f8(x); } } // namespace ck::utils diff --git a/include/ck/utility/flush_icache.hpp b/include/ck/utility/flush_icache.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7378ba5c26dc14421076099680cf408d87fc07a8 --- /dev/null +++ b/include/ck/utility/flush_icache.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
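One note on run_cast_to_f8 above before the new flush_icache header: the subtle part is the round-to-nearest-even update, which adds the dropped bits back onto the mantissa so that a carry into the kept bits happens exactly when rounding up is correct, and ties go to the value whose kept LSB is even. The standalone sketch below isolates that trick; the rne_round helper is hypothetical and deliberately ignores the exponent alignment and implicit leading 1 handled by the real conversion code.

// rne_round.cpp -- illustrative only; mirrors the midpoint/odd logic used above.
#include <cstdint>
#include <cstdio>

// Drop 'shift' low bits from 'mantissa' with round-to-nearest-even.
uint32_t rne_round(uint32_t mantissa, int shift)
{
    const uint32_t drop_mask = (1u << shift) - 1;
    const bool midpoint = (mantissa & drop_mask) == (1u << (shift - 1));
    const bool odd      = (mantissa & (1u << shift)) != 0; // bit that becomes the kept LSB
    // Adding the residual to itself carries into the kept bits only when the residual
    // exceeds one half; at the exact midpoint, force the carry only when the kept LSB is odd.
    mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
    return mantissa >> shift;
}

int main()
{
    // Dropping 4 bits: 23/16 -> 1, 25/16 -> 2, 24/16 (tie, odd keep) -> 2, 40/16 (tie, even keep) -> 2.
    std::printf("%u %u %u %u\n",
                rne_round(0x17, 4), rne_round(0x19, 4), rne_round(0x18, 4), rne_round(0x28, 4));
    return 0;
}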
+ +#pragma once + +#include + +namespace ck { +static __global__ void flush_icache() +{ + asm __volatile__("s_icache_inv \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" :: + :); +} +} // namespace ck diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index b58b2b33191e23db2bda3f9d15f647f6d4f85f58..65efaf388ae528feab9fcf06818e111244fa091a 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -72,6 +72,18 @@ inner_product(const float4_t& a, const float4_t& b, f c); } +template <> +__device__ void inner_product(const bhalf_t& a, const bhalf_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + +template <> +__device__ void inner_product(const half_t& a, const half_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + template <> __device__ void inner_product(const half2_t& a, const half2_t& b, float& c) { @@ -180,6 +192,8 @@ inner_product(const int8x4_t& a, const int8x4_t& b, #else c = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b), c, false); #endif +#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11) + c = __builtin_amdgcn_sudot4(true, bit_cast(a), true, bit_cast(b), c, false); #else const vector_type a_vector{a}; const vector_type b_vector{b}; diff --git a/include/ck/utility/inner_product_dpp8.hpp b/include/ck/utility/inner_product_dpp8.hpp index ccd7a4e6284fa07d7f08a8559c782ff8767d0bd6..f079e2ca6490676d6c0e83e9e74dba205a8d6583 100644 --- a/include/ck/utility/inner_product_dpp8.hpp +++ b/include/ck/utility/inner_product_dpp8.hpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once + #include "amd_gemm_dpp.hpp" #include "data_type.hpp" #include "type_convert.hpp" @@ -10,6 +11,9 @@ namespace ck { namespace dpp8 { +/// Number of lanes that can share data using DPP8 modifiers. +constexpr index_t lane_group_size = 8; + template __device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7a324a6c458b3f1b8bb8037ccfac76e5eadceee0 --- /dev/null +++ b/include/ck/utility/is_detected.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +namespace detail { +template class Op, class... Args> +struct detector +{ + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> +{ + using value_t = std::true_type; + using type = Op; +}; +} // namespace detail + +struct nonesuch +{ + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + void operator=(nonesuch const&) = delete; +}; + +template